Merge pull request #17 from sgrvinod/0.3.1
0.3.1
sgrvinod authored Nov 27, 2024
2 parents 2acccd1 + aa8e611 commit 972aac3
Showing 10 changed files with 104 additions and 74 deletions.
16 changes: 15 additions & 1 deletion CHANGELOG.md
@@ -1,10 +1,24 @@
# Change Log

## v0.3.1

### Added

* A [**`pyproject.toml`** file](https://github.com/sgrvinod/chess-transformers/blob/main/pyproject.toml) has been added in compliance with [PEP 660](https://peps.python.org/pep-0660/). While the inclusion of a `setup.py` file is not deprecated, its use as a command-line tool, such as in the legacy `setup.py develop` method for performing an editable installation, is now deprecated.

### Changed

* **`chess_transformers.train.datasets.ChessDataset`** was optimized for large datasets. It no longer builds or indexes a list of row indices for its data split; only the split's starting offset is stored.
* The `TRAINING_CHECKPOINT` parameter in each of the model configs in **`chess_transformers.configs.models`** was set to `None` to reflect the correct starting conditions for a model that has not yet been trained.
* Dynamic shape tracing is disabled for the compilation of [*CT-ED-45*](https://github.com/sgrvinod/chess-transformers#ct-ed-45) to prevent memory leaks as seen in [#16](https://github.com/sgrvinod/chess-transformers/issues/16).
* References to `torch.cuda.amp.GradScaler(...)` have been replaced by `torch.amp.GradScaler(device="cuda", ...)` following its deprecation.

## v0.3.0

### Added

* There are 3 new datasets: [*ML23c*](https://github.com/sgrvinod/chess-transformers#ml23c), [*GC22c*](https://github.com/sgrvinod/chess-transformers#gc22c), and [*ML23d*](https://github.com/sgrvinod/chess-transformers#ml23d).
* A new naming convention for datasets is used. Datasets are now named in the format "[*PGN Fileset*][*Filters*]". For example, *LE1222* is now called [*LE22ct*](https://github.com/sgrvinod/chess-transformers#le22ct), where *LE22* is the name of the PGN fileset from which this dataset was derived, and "*c*" and "*t*" are filters for games that ended in checkmate and games that used a specific time control, respectively.
* [*CT-EFT-85*](https://github.com/sgrvinod/chess-transformers#ct-eft-85) is a new trained model with about 85 million parameters.
* **`chess_transformers.train.utils.get_lr()`** now accepts new arguments, `schedule` and `decay`, to accommodate a new learning rate schedule: exponential decay after warmup (sketched below).
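The changelog names the new `schedule` and `decay` arguments of `get_lr()` but not their semantics. Here is a minimal sketch of what "exponential decay after warmup" can look like; only the two argument names come from the changelog, while the formulas, schedule names, and default values are assumptions for illustration.

```python
import math

def get_lr(step, d_model, warmup_steps, schedule="vaswani", decay=0.06):
    # Illustrative only: the names `schedule` and `decay` come from the
    # changelog; everything else here is assumed. Expects step >= 1.
    if schedule == "vaswani":
        # "Attention Is All You Need": linear warmup, then ~step^-0.5 decay
        return (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)
    if schedule == "exp_decay":
        # Same linear warmup to the peak LR, then exponential decay
        peak = (d_model ** -0.5) * (warmup_steps ** -0.5)
        if step < warmup_steps:
            return peak * step / warmup_steps
        return peak * math.exp(-decay * (step - warmup_steps) / warmup_steps)
    raise NotImplementedError(schedule)

# e.g. get_lr(step=20000, d_model=512, warmup_steps=8000, schedule="exp_decay")
```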
2 changes: 1 addition & 1 deletion README.md
@@ -4,7 +4,7 @@

<h1 align="center"><i>Chess Transformers</i></h1>
<p align="center"><i>Teaching transformers to play chess</i></p>
<p align="center"> <a href="https://github.com/sgrvinod/chess-transformers/releases/tag/v0.3.0"><img alt="Version" src="https://img.shields.io/github/v/tag/sgrvinod/chess-transformers?label=version"></a> <a href="https://github.com/sgrvinod/chess-transformers/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/sgrvinod/chess-transformers?label=license"></a></p>
<p align="center"> <a href="https://github.com/sgrvinod/chess-transformers/releases/tag/v0.3.1"><img alt="Version" src="https://img.shields.io/github/v/tag/sgrvinod/chess-transformers?label=version"></a> <a href="https://github.com/sgrvinod/chess-transformers/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/sgrvinod/chess-transformers?label=license"></a></p>
<br>

*Chess Transformers* is a library for training transformer models to play chess by learning from human games.
4 changes: 2 additions & 2 deletions chess_transformers/configs/models/CT-E-20.py
@@ -94,8 +94,8 @@
     pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints" / NAME
 ) # folder containing checkpoints
 TRAINING_CHECKPOINT = (
-    NAME + ".pt"
-) # path to model checkpoint to resume training, None if none
+    None # path to model checkpoint (NAME + ".pt") to resume training, None if none
+)
 CHECKPOINT_AVG_PREFIX = (
     "step" # prefix to add to checkpoint name when saving checkpoints for averaging
 )
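For context on the `TRAINING_CHECKPOINT` change repeated across the configs in this commit, here is a hypothetical resume gate showing why `None` is the correct value before any checkpoint exists. This is not the project's actual training code; the checkpoint dictionary key is an assumption.

```python
import os
import torch

def maybe_resume(model, config):
    """Load weights from TRAINING_CHECKPOINT if one is configured (sketch)."""
    if config.TRAINING_CHECKPOINT is None:
        return model  # fresh run: nothing to restore yet
    state = torch.load(
        os.path.join(config.CHECKPOINT_FOLDER, config.TRAINING_CHECKPOINT),
        map_location="cpu",
    )
    model.load_state_dict(state["model_state_dict"])  # assumed key name
    return model
```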
6 changes: 3 additions & 3 deletions chess_transformers/configs/models/CT-ED-45.py
@@ -50,7 +50,7 @@
 N_MOVES = 10 # expected maximum length of move sequences in the model, <= MAX_MOVE_SEQUENCE_LENGTH
 DISABLE_COMPILATION = False # disable model compilation?
 COMPILATION_MODE = "default" # mode of model compilation (see torch.compile())
-DYNAMIC_COMPILATION = True # expect tensors with dynamic shapes?
+DYNAMIC_COMPILATION = False # expect tensors with dynamic shapes?
 SAMPLING_K = 1 # k in top-k sampling model predictions during play
 MODEL = ChessTransformer # custom PyTorch model to train

@@ -94,8 +94,8 @@
     pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints" / NAME
 ) # folder containing checkpoints
 TRAINING_CHECKPOINT = (
-    NAME + ".pt"
-) # path to model checkpoint to resume training, None if none
+    None # path to model checkpoint (NAME + ".pt") to resume training, None if none
+)
 CHECKPOINT_AVG_PREFIX = (
     "step" # prefix to add to checkpoint name when saving checkpoints for averaging
 )
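The `DYNAMIC_COMPILATION` flag above presumably reaches `torch.compile()` as its `dynamic` argument (the config comments point at `torch.compile()`, but the exact call site is not shown in this diff). Roughly:

```python
import torch

model = torch.nn.Linear(64, 64)  # stand-in for the chess transformer

# dynamic=True lets the compiled graph admit varying tensor shapes, which is
# what leaked memory for CT-ED-45 in issue #16; dynamic=False pins the shapes.
compiled = torch.compile(model, mode="default", dynamic=False)
```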
4 changes: 2 additions & 2 deletions chess_transformers/configs/models/CT-EFT-20.py
@@ -94,8 +94,8 @@
     pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints" / NAME
 ) # folder containing checkpoints
 TRAINING_CHECKPOINT = (
-    NAME + ".pt"
-) # path to model checkpoint to resume training, None if none
+    None # path to model checkpoint (NAME + ".pt") to resume training, None if none
+)
 CHECKPOINT_AVG_PREFIX = (
     "step" # prefix to add to checkpoint name when saving checkpoints for averaging
 )
4 changes: 3 additions & 1 deletion chess_transformers/configs/models/CT-EFT-85.py
@@ -93,7 +93,9 @@
 CHECKPOINT_FOLDER = str(
     pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints" / NAME
 ) # folder containing checkpoints
-TRAINING_CHECKPOINT = None # path to model checkpoint to resume training, None if none
+TRAINING_CHECKPOINT = (
+    None # path to model checkpoint (NAME + ".pt") to resume training, None if none
+)
 AVERAGE_STEPS = {491000, 492500, 494000, 495500, 497000, 498500, 500000}
 CHECKPOINT_AVG_PREFIX = (
     "step" # prefix to add to checkpoint name when saving checkpoints for averaging
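`AVERAGE_STEPS` and `CHECKPOINT_AVG_PREFIX` in this config imply Transformer-style checkpoint averaging over the listed training steps. A sketch of that technique follows; the file naming and the `"model_state_dict"` key are assumptions, not the project's confirmed layout.

```python
import torch

def average_checkpoints(paths):
    """Average model weights across several saved training steps (sketch)."""
    avg = None
    for path in paths:
        state = torch.load(path, map_location="cpu")["model_state_dict"]
        if avg is None:
            avg = {k: v.clone().float() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].float()
    return {k: v / len(paths) for k, v in avg.items()}

# Hypothetical usage with the steps listed above:
# paths = [f"step{s}_CT-EFT-85.pt" for s in sorted(AVERAGE_STEPS)]
# averaged = average_checkpoints(paths)
```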
52 changes: 31 additions & 21 deletions chess_transformers/train/datasets.py
@@ -34,19 +34,15 @@ def __init__(self, data_folder, h5_file, split, n_moves=None, **unused):
         # Open table in H5 file
         self.h5_file = tb.open_file(os.path.join(data_folder, h5_file), mode="r")
         self.encoded_table = self.h5_file.root.encoded_data
+        self.split = split
 
-        # Create indices
-        # TODO: optimize by using a start_index and not a list of indices
         if split == "train":
-            self.indices = list(range(0, self.encoded_table.attrs.val_split_index))
+            self.first_index = 0
         elif split == "val":
-            self.indices = list(
-                range(
-                    self.encoded_table.attrs.val_split_index, self.encoded_table.nrows
-                )
-            )
+            self.first_index = self.encoded_table.attrs.val_split_index
         elif split is None:
-            self.indices = list(range(0, self.encoded_table.nrows))
+            self.first_index = 0
         else:
             raise NotImplementedError

@@ -56,33 +52,41 @@ def __init__(self, data_folder, h5_file, split, n_moves=None, **unused):
         if n_moves is not None:
             # This is the same as min(MAX_MOVE_SEQUENCE_LENGTH, n_moves)
             self.n_moves = min(
-                len(self.encoded_table[self.indices[0]]["moves"]) - 1, n_moves
+                len(self.encoded_table[self.first_index]["moves"]) - 1, n_moves
             )
         else:
-            self.n_moves = len(self.encoded_table[self.indices[0]]["moves"]) - 1
+            self.n_moves = len(self.encoded_table[self.first_index]["moves"]) - 1
 
     def __getitem__(self, i):
-        turns = torch.IntTensor([self.encoded_table[self.indices[i]]["turn"]])
+        turns = torch.IntTensor([self.encoded_table[self.first_index + i]["turn"]])
         white_kingside_castling_rights = torch.IntTensor(
-            [self.encoded_table[self.indices[i]]["white_kingside_castling_rights"]]
+            [self.encoded_table[self.first_index + i]["white_kingside_castling_rights"]]
         ) # (1)
         white_queenside_castling_rights = torch.IntTensor(
-            [self.encoded_table[self.indices[i]]["white_queenside_castling_rights"]]
+            [
+                self.encoded_table[self.first_index + i][
+                    "white_queenside_castling_rights"
+                ]
+            ]
         ) # (1)
         black_kingside_castling_rights = torch.IntTensor(
-            [self.encoded_table[self.indices[i]]["black_kingside_castling_rights"]]
+            [self.encoded_table[self.first_index + i]["black_kingside_castling_rights"]]
         ) # (1)
         black_queenside_castling_rights = torch.IntTensor(
-            [self.encoded_table[self.indices[i]]["black_queenside_castling_rights"]]
+            [
+                self.encoded_table[self.first_index + i][
+                    "black_queenside_castling_rights"
+                ]
+            ]
         ) # (1)
         board_position = torch.IntTensor(
-            self.encoded_table[self.indices[i]]["board_position"]
+            self.encoded_table[self.first_index + i]["board_position"]
         ) # (64)
         moves = torch.LongTensor(
-            self.encoded_table[self.indices[i]]["moves"][: self.n_moves + 1]
+            self.encoded_table[self.first_index + i]["moves"][: self.n_moves + 1]
         ) # (n_moves + 1)
         length = torch.LongTensor(
-            [self.encoded_table[self.indices[i]]["length"]]
+            [self.encoded_table[self.first_index + i]["length"]]
         ).clamp(
             max=self.n_moves
         ) # (1), value <= n_moves
@@ -99,7 +103,14 @@ def __getitem__(self, i):
         }
 
     def __len__(self):
-        return len(self.indices)
+        if self.split == "train":
+            return self.encoded_table.attrs.val_split_index
+        elif self.split == "val":
+            return self.encoded_table.nrows - self.encoded_table.attrs.val_split_index
+        elif self.split is None:
+            return self.encoded_table.nrows
+        else:
+            raise NotImplementedError


class ChessDatasetFT(Dataset):
@@ -175,12 +186,11 @@ def __len__(self):
elif self.split == "val":
return self.encoded_table.nrows - self.encoded_table.attrs.val_split_index
elif self.split is None:
self.encoded_table.nrows
return self.encoded_table.nrows
else:
raise NotImplementedError



if __name__ == "__main__":
# Get configuration
parser = argparse.ArgumentParser()
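The pattern of the `ChessDataset` change above, demonstrated in isolation with a plain list standing in for the PyTables array:

```python
encoded_table = list(range(100))  # stand-in for the H5 table's rows
val_split_index = 80

# v0.3.0: materialize one Python int per row of the split (O(rows) memory)
indices = list(range(0, val_split_index))
assert encoded_table[indices[5]] == 5

# v0.3.1: keep a single integer offset per split (O(1) memory, same row)
first_index = 0  # for split == "val", this would be val_split_index
assert encoded_table[first_index + 5] == 5
```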
4 changes: 2 additions & 2 deletions chess_transformers/train/train.py
@@ -5,7 +5,7 @@
 import torch.backends.cudnn as cudnn
 
 from tqdm import tqdm
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter

@@ -96,7 +96,7 @@ def train_model(CONFIG):
     criterion = criterion.to(DEVICE)
 
     # AMP scaler
-    scaler = GradScaler(enabled=CONFIG.USE_AMP)
+    scaler = GradScaler(device=DEVICE, enabled=CONFIG.USE_AMP)
 
     # Find total epochs to train
     epochs = (CONFIG.N_STEPS // (len(train_loader) // CONFIG.BATCHES_PER_STEP)) + 1
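The `GradScaler` migration above, in a self-contained loop (a minimal sketch, not the project's training loop):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(8, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Old, now deprecated: scaler = torch.cuda.amp.GradScaler(enabled=True)
scaler = torch.amp.GradScaler(device=device, enabled=(device == "cuda"))

x = torch.randn(4, 8, device=device)
y = torch.randint(0, 2, (4,), device=device)
with torch.amp.autocast(device_type=device, enabled=(device == "cuda")):
    loss = torch.nn.functional.cross_entropy(model(x), y)

scaler.scale(loss).backward()  # scale loss to avoid fp16 gradient underflow
scaler.step(optimizer)         # unscales gradients, then steps the optimizer
scaler.update()                # adapts the scale factor for the next step
```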
44 changes: 44 additions & 0 deletions pyproject.toml
@@ -0,0 +1,44 @@
[build-system]
requires = ["setuptools >= 64"]
build-backend = "setuptools.build_meta"

[project]
name = "chess-transformers"
version = "0.3.1"
description = "Teaching transformers to play chess."
authors = [{ name = "Sagar Vinodababu", email = "[email protected]" }]
maintainers = [{ name = "Sagar Vinodababu", email = "[email protected]" }]
readme = "README.md"
requires-python = ">=3.6.0"
dependencies = [
"beautifulsoup4==4.12.3",
"chess==1.10.0",
"colorama==0.4.5",
"ipython==8.17.2",
"Markdown==3.3.4",
"py_cpuinfo==9.0.0",
"regex==2024.7.24",
"scipy==1.13.1",
"setuptools==69.0.3",
"tables==3.9.2",
"tabulate==0.9.0",
"torch==2.4.0",
"tqdm==4.64.1",
"tensorboard==2.18.0",
]
license = { text = "MIT License" }
keywords = ["transformer", "chess", "pytorch", "deep learning", "chess engine"]
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]

[project.urls]
homepage = "https://github.com/sgrvinod/chess-transformers"
source = "https://github.com/sgrvinod/chess-transformers"
changelog = "https://github.com/sgrvinod/chess-transformers/blob/main/CHANGELOG.md"
releasenotes = "https://github.com/sgrvinod/chess-transformers/releases"
issues = "https://github.com/sgrvinod/chess-transformers/issues"
42 changes: 1 addition & 41 deletions setup.py
@@ -1,43 +1,3 @@
 from setuptools import setup, find_packages
 
-with open("README.md", mode="r", encoding="utf-8") as readme_file:
-    readme = readme_file.read()
-
-
-setup(
-    name="chess-transformers",
-    version="0.3.0",
-    author="Sagar Vinodababu",
-    author_email="[email protected]",
-    description="Chess Transformers",
-    long_description=readme,
-    long_description_content_type="text/markdown",
-    license="MIT License",
-    url="https://github.com/sgrvinod/chess-transformers",
-    download_url="https://github.com/sgrvinod/chess-transformers",
-    packages=find_packages(),
-    python_requires=">=3.6.0",
-    install_requires=[
-        "beautifulsoup4==4.12.3",
-        "chess==1.10.0",
-        "colorama==0.4.5",
-        "ipython==8.17.2",
-        "Markdown==3.3.4",
-        "py_cpuinfo==9.0.0",
-        "regex==2024.7.24",
-        "scipy==1.13.1",
-        "setuptools==69.0.3",
-        "tables==3.9.2",
-        "tabulate==0.9.0",
-        "torch==2.4.0",
-        "tqdm==4.64.1",
-    ],
-    classifiers=[
-        "Development Status :: 3 - Alpha",
-        "Intended Audience :: Science/Research",
-        "License :: OSI Approved :: MIT License",
-        "Programming Language :: Python :: 3.9",
-        "Topic :: Scientific/Engineering :: Artificial Intelligence",
-    ],
-    keywords="transformer networks chess pytorch deep learning",
-)
+setup(packages=find_packages())
