Skip to content

Commit

Permalink
using tonic in place of spikedata
Browse files Browse the repository at this point in the history
* adapted notebooks to use tonic instead of the deprecated `spikedata` module

* added a `DATADIR` variable to let user choose their data directory (`/tmp` has the nice property of getting easily flushed..)

*  fixed an error to a call of a deprecated torch function: `AttributeError: module 'torch' has no attribute '_six'`

* removed all outputs to reduce size of notebooks
  • Loading branch information
laurentperrinet committed Dec 5, 2023
1 parent a65ee76 commit d723f37
Show file tree
Hide file tree
Showing 5 changed files with 353 additions and 341 deletions.
246 changes: 143 additions & 103 deletions examples/dataloaders/DVS_Gesture.ipynb
Original file line number Diff line number Diff line change
@@ -1,69 +1,92 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "DVS_Gesture.ipynb",
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "K7t_qlnTVWWS"
},
"outputs": [],
"source": [
"!pip install snntorch"
],
"execution_count": null,
"outputs": []
"# have you installed snn torch?\n",
"# %pip install snntorch"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YpJrTa3XVaSQ"
},
"outputs": [],
"source": [
"import snntorch as snn\n",
"from snntorch.spikevision import spikedata \n",
"from torch.utils.data import DataLoader"
],
"import snntorch as snn\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"outputs": []
"metadata": {},
"outputs": [],
"source": [
"DATADIR = \"/tmp/data\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "N7kYJ5reVeDA"
},
"metadata": {},
"source": [
"## Download Dataset"
"## Download Dataset using `spikedata` (deprecated)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kbHJ827iVcYY"
},
"outputs": [],
"source": [
"# note that a default transform is already applied to keep things easy\n",
"# from snntorch.spikevision import spikedata \n",
"# # note that a default transform is already applied to keep things easy\n",
"\n",
"dvs_train = spikedata.DVSGesture(\"data/dvsgesture\", train=True, dt=1000, num_steps=500, ds=1) # ds: spatial compression; dt: temporal compressiondvs_test\n",
"dvs_test = spikedata.DVSGesture(\"data/dvsgesture\", train=False, dt=1000, num_steps=1800, ds=1)"
],
"# train_ds = spikedata.DVSGesture(\"/tmp/data/dvsgesture\", train=True, dt=1000, num_steps=500, ds=1) # ds: spatial compression; dt: temporal compressiondvs_test\n",
"# test_ds = spikedata.DVSGesture(\"/tmp/data/dvsgesture\", train=False, dt=1000, num_steps=1800, ds=1)\n",
"# test_ds"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Download Dataset using `tonic`"
]
},
{
"cell_type": "code",
"execution_count": null,
"outputs": []
"metadata": {},
"outputs": [],
"source": [
"import tonic\n",
"import tonic.transforms as transforms\n",
"\n",
"sensor_size = tonic.datasets.DVSGesture.sensor_size\n",
"\n",
"# Denoise removes isolated, one-off events\n",
"# time_window\n",
"frame_transform = transforms.Compose([transforms.Denoise(filter_time=10000),\n",
" transforms.ToFrame(sensor_size=sensor_size,\n",
" time_window=1000)\n",
" ])\n",
"\n",
"train_ds = tonic.datasets.DVSGesture(save_to=DATADIR, transform=frame_transform, train=True)\n",
"test_ds = tonic.datasets.DVSGesture(save_to=DATADIR, transform=frame_transform, train=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
Expand All @@ -72,22 +95,9 @@
"id": "AWQTG9lhVmKY",
"outputId": "3d607d3e-b416-4a47-e61e-4d562ed52dba"
},
"outputs": [],
"source": [
"dvs_test"
],
"execution_count": 1,
"outputs": [
{
"output_type": "error",
"ename": "NameError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-1-1b6453ff6f37>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mdvs_test\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mNameError\u001b[0m: name 'dvs_test' is not defined"
]
}
"test_ds"
]
},
{
Expand All @@ -101,93 +111,123 @@
},
{
"cell_type": "code",
"metadata": {
"id": "TWhGAeVLVws5"
},
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dvs_train_dl = DataLoader(dvs_train, shuffle=True, batch_size=64)\n",
"dvs_test_dl = DataLoader(dvs_test, shuffle=False, batch_size=64)"
],
"from torch.utils.data import DataLoader\n",
"\n",
"train_dl = DataLoader(train_ds, shuffle=True, batch_size=64)\n",
"test_dl = DataLoader(test_ds, shuffle=False, batch_size=64)"
]
},
{
"cell_type": "code",
"execution_count": null,
"outputs": []
"metadata": {},
"outputs": [],
"source": [
"print('the number of items in the dataset is', len(train_dl.dataset))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Play with Data"
]
},
{
"cell_type": "code",
"metadata": {
"id": "MhwnXKs_VxgZ"
},
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# get a feel for the data\n",
"dvs_train_dl.dataset[0][0].size()"
],
"i_item = 42 # random index into a sample\n",
"data, label = train_dl.dataset[i_item]\n",
"import torch\n",
"data = torch.Tensor(data)\n",
"\n",
"print('The data sample has size', data.shape)\n",
"print(f\"in case you're blind AF, the target is: {label} ({train_ds.classes[label]})\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"outputs": []
"metadata": {},
"outputs": [],
"source": [
"train_ds.classes"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i8mEmv8PV9qK"
},
"metadata": {},
"source": [
"## Visualization"
"## Visualize"
]
},
{
"cell_type": "code",
"metadata": {
"id": "iTt_4BlqV3vP"
},
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"n = 3 # random index into a sample\n",
"import matplotlib.pyplot as plt\n",
"import snntorch.spikeplot as splt\n",
"from IPython.display import HTML, display\n",
"import numpy as np\n",
"\n",
"# flatten on-spikes and off-spikes into one channel\n",
"a = (dvs_train_dl.dataset[n][0][:, 0] + dvs_train_dl.dataset[n][0][:, 1])\n",
"print(f\"in case you're blind AF, the target is: {spikedata.dvs_gesture.mapping[dvs_train_dl.dataset[n][1]]}\")\n",
"\n",
"# a = (train_dl.dataset[n][0][:, 0] + train_dl.dataset[n][0][:, 1])\n",
"a = (data[:300, 0, :, :] - data[:300, 1, :, :])\n",
"# a = np.swapaxes(a, 0, -1)\n",
"# Plot\n",
"fig, ax = plt.subplots()\n",
"anim = splt.animator(a, fig, ax, interval=10)\n",
"anim = splt.animator(a, fig, ax, interval=200)\n",
"HTML(anim.to_html5_video())\n",
"\n",
"# anim.save('nmnist_animation.mp4', writer = 'ffmpeg', fps=50) "
],
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iDDqdJ1YWBns"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import torch"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "JpiRF7pqV9AJ"
},
"source": [
"plt.imshow(torch.sum(dvs_train_dl.dataset[4][0][:300,0,:,:], axis=0), cmap='hot')\n",
"import torch\n",
"plt.imshow(torch.sum(data[:300, 0,:,:], axis=0), cmap='hot')\n",
"plt.colorbar()"
],
"execution_count": null,
"outputs": []
]
}
],
"metadata": {
"colab": {
"name": "DVS_Gesture.ipynb",
"provenance": []
},
{
"cell_type": "code",
"metadata": {
"id": "7gZariHNWH-X"
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"source": [
"spikedata.dvs_gesture.mapping"
],
"execution_count": null,
"outputs": []
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
]
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Loading

0 comments on commit d723f37

Please sign in to comment.