diff --git a/CHANGELOG.md b/CHANGELOG.md
index f7fecbd..831a07d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,10 +1,24 @@
# Change Log
+## v0.3.1
+
+### Added
+
+* A [**`pyproject.toml`**
+file](https://github.com/sgrvinod/chess-transformers/blob/main/pyproject.toml) has been added in compliance with [PEP 660](https://peps.python.org/pep-0660/). While the inclusion of a `setup.py` file is not deprecated, its use as a command-line tool, such as in the legacy `setup.py develop` method for performing an editable installation is now deprecated.
+
+### Changed
+
+* **`chess_transformers.train.datasets.ChessDataset`** was optimized for large datasets. A list of indices for the data split is no longer maintained or indexed in the dataset.
+* The `TRAINING_CHECKPOINT` parameter in each of **`chess_transformers.configs.models`** was set to `None` to reflect the correct conditions for beginning training of a model.
+* Dynamic shape tracing is disabled for the compilation of [*CT-ED-45*](https://github.com/sgrvinod/chess-transformers#ct-ed-45) to prevent memory leaks as seen in [#16](https://github.com/sgrvinod/chess-transformers/issues/16).
+* References to `torch.cuda.amp.GradScaler(...)` have been replaced by `torch.amp.GradScaler(device="cuda", ...)` following its deprecation.
+
## v0.3.0
### Added
-* There are 3 new datasets: [ML23c](https://github.com/sgrvinod/chess-transformers#ml23c), [GC22c](https://github.com/sgrvinod/chess-transformers#gc22c), and [ML23d](https://github.com/sgrvinod/chess-transformers#ml23d).
+* There are 3 new datasets: [*ML23c*](https://github.com/sgrvinod/chess-transformers#ml23c), [*GC22c*](https://github.com/sgrvinod/chess-transformers#gc22c), and [*ML23d*](https://github.com/sgrvinod/chess-transformers#ml23d).
* A new naming convention for datasets is used. Datasets are now named in the format "[*PGN Fileset*][*Filters*]". For example, *LE1222* is now called [*LE22ct*](https://github.com/sgrvinod/chess-transformers#le22ct), where *LE22* is the name of the PGN fileset from which this dataset was derived, and "*c*", "*t*" are filters for games that ended in checkmates and games that used a specific time control respectively.
* [*CT-EFT-85*](https://github.com/sgrvinod/chess-transformers#ct-eft-85) is a new trained model with about 85 million parameters.
* **`chess_transformers.train.utils.get_lr()`** now accepts new arguments, `schedule` and `decay`, to accomodate a new learning rate schedule: exponential decay after warmup.
diff --git a/README.md b/README.md
index e8e2c6e..56e84b9 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
Chess Transformers
Teaching transformers to play chess
-
+
*Chess Transformers* is a library for training transformer models to play chess by learning from human games.
diff --git a/chess_transformers/configs/models/CT-E-20.py b/chess_transformers/configs/models/CT-E-20.py
index b003bfd..1694f0f 100644
--- a/chess_transformers/configs/models/CT-E-20.py
+++ b/chess_transformers/configs/models/CT-E-20.py
@@ -94,8 +94,8 @@
pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints" / NAME
) # folder containing checkpoints
TRAINING_CHECKPOINT = (
- NAME + ".pt"
-) # path to model checkpoint to resume training, None if none
+ None # path to model checkpoint (NAME + ".pt") to resume training, None if none
+)
CHECKPOINT_AVG_PREFIX = (
"step" # prefix to add to checkpoint name when saving checkpoints for averaging
)
diff --git a/chess_transformers/configs/models/CT-ED-45.py b/chess_transformers/configs/models/CT-ED-45.py
index d61cf64..81dbcb5 100644
--- a/chess_transformers/configs/models/CT-ED-45.py
+++ b/chess_transformers/configs/models/CT-ED-45.py
@@ -50,7 +50,7 @@
N_MOVES = 10 # expected maximum length of move sequences in the model, <= MAX_MOVE_SEQUENCE_LENGTH
DISABLE_COMPILATION = False # disable model compilation?
COMPILATION_MODE = "default" # mode of model compilation (see torch.compile())
-DYNAMIC_COMPILATION = True # expect tensors with dynamic shapes?
+DYNAMIC_COMPILATION = False # expect tensors with dynamic shapes?
SAMPLING_K = 1 # k in top-k sampling model predictions during play
MODEL = ChessTransformer # custom PyTorch model to train
@@ -94,8 +94,8 @@
pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints" / NAME
) # folder containing checkpoints
TRAINING_CHECKPOINT = (
- NAME + ".pt"
-) # path to model checkpoint to resume training, None if none
+ None # path to model checkpoint (NAME + ".pt") to resume training, None if none
+)
CHECKPOINT_AVG_PREFIX = (
"step" # prefix to add to checkpoint name when saving checkpoints for averaging
)
diff --git a/chess_transformers/configs/models/CT-EFT-20.py b/chess_transformers/configs/models/CT-EFT-20.py
index 54fb64f..4f2f527 100644
--- a/chess_transformers/configs/models/CT-EFT-20.py
+++ b/chess_transformers/configs/models/CT-EFT-20.py
@@ -94,8 +94,8 @@
pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints" / NAME
) # folder containing checkpoints
TRAINING_CHECKPOINT = (
- NAME + ".pt"
-) # path to model checkpoint to resume training, None if none
+ None # path to model checkpoint (NAME + ".pt") to resume training, None if none
+)
CHECKPOINT_AVG_PREFIX = (
"step" # prefix to add to checkpoint name when saving checkpoints for averaging
)
diff --git a/chess_transformers/configs/models/CT-EFT-85.py b/chess_transformers/configs/models/CT-EFT-85.py
index 93a3350..32a6850 100644
--- a/chess_transformers/configs/models/CT-EFT-85.py
+++ b/chess_transformers/configs/models/CT-EFT-85.py
@@ -93,7 +93,9 @@
CHECKPOINT_FOLDER = str(
pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints" / NAME
) # folder containing checkpoints
-TRAINING_CHECKPOINT = None # path to model checkpoint to resume training, None if none
+TRAINING_CHECKPOINT = (
+ None # path to model checkpoint (NAME + ".pt") to resume training, None if none
+)
AVERAGE_STEPS = {491000, 492500, 494000, 495500, 497000, 498500, 500000}
CHECKPOINT_AVG_PREFIX = (
"step" # prefix to add to checkpoint name when saving checkpoints for averaging
diff --git a/chess_transformers/train/datasets.py b/chess_transformers/train/datasets.py
index 63a63f8..e2ab8eb 100644
--- a/chess_transformers/train/datasets.py
+++ b/chess_transformers/train/datasets.py
@@ -34,19 +34,15 @@ def __init__(self, data_folder, h5_file, split, n_moves=None, **unused):
# Open table in H5 file
self.h5_file = tb.open_file(os.path.join(data_folder, h5_file), mode="r")
self.encoded_table = self.h5_file.root.encoded_data
+ self.split = split
# Create indices
- # TODO: optimize by using a start_index and not a list of indices
if split == "train":
- self.indices = list(range(0, self.encoded_table.attrs.val_split_index))
+ self.first_index = 0
elif split == "val":
- self.indices = list(
- range(
- self.encoded_table.attrs.val_split_index, self.encoded_table.nrows
- )
- )
+ self.first_index = self.encoded_table.attrs.val_split_index
elif split is None:
- self.indices = list(range(0, self.encoded_table.nrows))
+ self.first_index = 0
else:
raise NotImplementedError
@@ -56,33 +52,41 @@ def __init__(self, data_folder, h5_file, split, n_moves=None, **unused):
if n_moves is not None:
# This is the same as min(MAX_MOVE_SEQUENCE_LENGTH, n_moves)
self.n_moves = min(
- len(self.encoded_table[self.indices[0]]["moves"]) - 1, n_moves
+ len(self.encoded_table[self.first_index]["moves"]) - 1, n_moves
)
else:
- self.n_moves = len(self.encoded_table[self.indices[0]]["moves"]) - 1
+ self.n_moves = len(self.encoded_table[self.first_index]["moves"]) - 1
def __getitem__(self, i):
- turns = torch.IntTensor([self.encoded_table[self.indices[i]]["turn"]])
+ turns = torch.IntTensor([self.encoded_table[self.first_index + i]["turn"]])
white_kingside_castling_rights = torch.IntTensor(
- [self.encoded_table[self.indices[i]]["white_kingside_castling_rights"]]
+ [self.encoded_table[self.first_index + i]["white_kingside_castling_rights"]]
) # (1)
white_queenside_castling_rights = torch.IntTensor(
- [self.encoded_table[self.indices[i]]["white_queenside_castling_rights"]]
+ [
+ self.encoded_table[self.first_index + i][
+ "white_queenside_castling_rights"
+ ]
+ ]
) # (1)
black_kingside_castling_rights = torch.IntTensor(
- [self.encoded_table[self.indices[i]]["black_kingside_castling_rights"]]
+ [self.encoded_table[self.first_index + i]["black_kingside_castling_rights"]]
) # (1)
black_queenside_castling_rights = torch.IntTensor(
- [self.encoded_table[self.indices[i]]["black_queenside_castling_rights"]]
+ [
+ self.encoded_table[self.first_index + i][
+ "black_queenside_castling_rights"
+ ]
+ ]
) # (1)
board_position = torch.IntTensor(
- self.encoded_table[self.indices[i]]["board_position"]
+ self.encoded_table[self.first_index + i]["board_position"]
) # (64)
moves = torch.LongTensor(
- self.encoded_table[self.indices[i]]["moves"][: self.n_moves + 1]
+ self.encoded_table[self.first_index + i]["moves"][: self.n_moves + 1]
) # (n_moves + 1)
length = torch.LongTensor(
- [self.encoded_table[self.indices[i]]["length"]]
+ [self.encoded_table[self.first_index + i]["length"]]
).clamp(
max=self.n_moves
) # (1), value <= n_moves
@@ -99,7 +103,14 @@ def __getitem__(self, i):
}
def __len__(self):
- return len(self.indices)
+ if self.split == "train":
+ return self.encoded_table.attrs.val_split_index
+ elif self.split == "val":
+ return self.encoded_table.nrows - self.encoded_table.attrs.val_split_index
+ elif self.split is None:
+ return self.encoded_table.nrows
+ else:
+ raise NotImplementedError
class ChessDatasetFT(Dataset):
@@ -175,12 +186,11 @@ def __len__(self):
elif self.split == "val":
return self.encoded_table.nrows - self.encoded_table.attrs.val_split_index
elif self.split is None:
- self.encoded_table.nrows
+ return self.encoded_table.nrows
else:
raise NotImplementedError
-
if __name__ == "__main__":
# Get configuration
parser = argparse.ArgumentParser()
diff --git a/chess_transformers/train/train.py b/chess_transformers/train/train.py
index b362978..378f22e 100644
--- a/chess_transformers/train/train.py
+++ b/chess_transformers/train/train.py
@@ -5,7 +5,7 @@
import torch.backends.cudnn as cudnn
from tqdm import tqdm
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
@@ -96,7 +96,7 @@ def train_model(CONFIG):
criterion = criterion.to(DEVICE)
# AMP scaler
- scaler = GradScaler(enabled=CONFIG.USE_AMP)
+ scaler = GradScaler(device=DEVICE, enabled=CONFIG.USE_AMP)
# Find total epochs to train
epochs = (CONFIG.N_STEPS // (len(train_loader) // CONFIG.BATCHES_PER_STEP)) + 1
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..62a4d5e
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,44 @@
+[build-system]
+requires = ["setuptools >= 64"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "chess-transformers"
+version = "0.3.1"
+description = "Teaching transformers to play chess."
+authors = [{ name = "Sagar Vinodababu", email = "sgrvinod@gmail.com" }]
+maintainers = [{ name = "Sagar Vinodababu", email = "sgrvinod@gmail.com" }]
+readme = "README.md"
+requires-python = ">=3.6.0"
+dependencies = [
+ "beautifulsoup4==4.12.3",
+ "chess==1.10.0",
+ "colorama==0.4.5",
+ "ipython==8.17.2",
+ "Markdown==3.3.4",
+ "py_cpuinfo==9.0.0",
+ "regex==2024.7.24",
+ "scipy==1.13.1",
+ "setuptools==69.0.3",
+ "tables==3.9.2",
+ "tabulate==0.9.0",
+ "torch==2.4.0",
+ "tqdm==4.64.1",
+ "tensorboard==2.18.0",
+]
+license = { text = "MIT License" }
+keywords = ["transformer", "chess", "pytorch", "deep learning", "chess engine"]
+classifiers = [
+ "Development Status :: 3 - Alpha",
+ "Intended Audience :: Science/Research",
+ "License :: OSI Approved :: MIT License",
+ "Programming Language :: Python :: 3.12",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+]
+
+[project.urls]
+homepage = "https://github.com/sgrvinod/chess-transformers"
+source = "https://github.com/sgrvinod/chess-transformers"
+changelog = "https://github.com/sgrvinod/chess-transformers/blob/main/CHANGELOG.md"
+releasenotes = "https://github.com/sgrvinod/chess-transformers/releases"
+issues = "https://github.com/sgrvinod/chess-transformers/issues"
diff --git a/setup.py b/setup.py
index 6ef1111..45642b9 100644
--- a/setup.py
+++ b/setup.py
@@ -1,43 +1,3 @@
from setuptools import setup, find_packages
-with open("README.md", mode="r", encoding="utf-8") as readme_file:
- readme = readme_file.read()
-
-
-setup(
- name="chess-transformers",
- version="0.3.0",
- author="Sagar Vinodababu",
- author_email="sgrvinod@gmail.com",
- description="Chess Transformers",
- long_description=readme,
- long_description_content_type="text/markdown",
- license="MIT License",
- url="https://github.com/sgrvinod/chess-transformers",
- download_url="https://github.com/sgrvinod/chess-transformers",
- packages=find_packages(),
- python_requires=">=3.6.0",
- install_requires=[
- "beautifulsoup4==4.12.3",
- "chess==1.10.0",
- "colorama==0.4.5",
- "ipython==8.17.2",
- "Markdown==3.3.4",
- "py_cpuinfo==9.0.0",
- "regex==2024.7.24",
- "scipy==1.13.1",
- "setuptools==69.0.3",
- "tables==3.9.2",
- "tabulate==0.9.0",
- "torch==2.4.0",
- "tqdm==4.64.1",
- ],
- classifiers=[
- "Development Status :: 3 - Alpha",
- "Intended Audience :: Science/Research",
- "License :: OSI Approved :: MIT License",
- "Programming Language :: Python :: 3.9",
- "Topic :: Scientific/Engineering :: Artificial Intelligence",
- ],
- keywords="transformer networks chess pytorch deep learning",
-)
+setup(packages=find_packages())