-
Notifications
You must be signed in to change notification settings - Fork 88
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
final edit, finishing touches. Python perf test: SPS: 2,045,564.37560…
…13566 on mediocre hardware. C perf test: SPS: 333299.687500. added README. formatted w/ clang-tidy Google for succinctness and readability.
- Loading branch information
xinpw8
committed
Jan 12, 2025
1 parent
791bfe5
commit fcba83c
Showing
6 changed files
with
1,555 additions
and
1,550 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# README | ||
Please star PufferLib on GitHub: it really makes a difference! | ||
https://github.com/pufferai/pufferlib | ||
|
||
by Daniel Addis, 2024 | ||
https://github.com/xinpw8 | ||
|
||
## Puffer Ocean Enduro | ||
This project contains a performant reinforcement-learning environment inspired by the classic Atari 2600 game Enduro. It uses C, Cython, and Python to provide an interactive RL environment, trainable with PufferLib. | ||
|
||
## Building & Setup | ||
1. Install dependencies. All commands should be run from the `PufferLib` top directory. | ||
```sh | ||
pip install -e .'[cleanrl]' | ||
``` | ||
Instructions on https://puffer.ai/docs.html | ||
|
||
2. Compilation is run when pufferlib is pip installed. After making changes to `cy_enduro.pyx`, `enduro.c`, or `enduro.h`, recompile with: | ||
```sh | ||
python setup.py build_ext --inplace | ||
``` | ||
|
||
3. To locally compile the C environment for testing (without Cython), run: | ||
```sh | ||
scripts/build_ocean.sh enduro local | ||
``` | ||
This builds a runnable `enduro` module in the PufferLib top directory, which you can run with: | ||
```sh | ||
./enduro | ||
``` | ||
Hold Shift to take control from the agent. | ||
|
||
## Training | ||
To train using the demo script with wandb logs, run: | ||
```sh | ||
python demo.py --env puffer_enduro --mode train --track | ||
``` | ||
Model files are saved at intervals specified in `config/ocean/enduro.ini` to the `experiments/` directory. | ||
|
||
## Evaluation | ||
To evaluate a local checkpoint, run: | ||
```sh | ||
python demo.py --env puffer_enduro --mode eval --eval-model-path your_model.pt | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,97 +1,92 @@ | ||
#include "enduro.h" | ||
|
||
#include <stddef.h> | ||
#include <stdio.h> | ||
#include <stdlib.h> | ||
#include <stddef.h> | ||
#include <time.h> | ||
#include "enduro.h" | ||
#include "raylib.h" | ||
#include "puffernet.h" | ||
|
||
#define MAX_ENEMIES 10 | ||
#include "puffernet.h" | ||
#include "raylib.h" | ||
|
||
void get_input(Enduro* env) { | ||
if ((IsKeyDown(KEY_DOWN) && IsKeyDown(KEY_RIGHT)) || (IsKeyDown(KEY_S) && IsKeyDown(KEY_D))) { | ||
env->actions[0] = ACTION_DOWNRIGHT; // Decelerate and move right | ||
} else if ((IsKeyDown(KEY_DOWN) && IsKeyDown(KEY_LEFT)) || (IsKeyDown(KEY_S) && IsKeyDown(KEY_A))) { | ||
env->actions[0] = ACTION_DOWNLEFT; // Decelerate and move left | ||
} else if (IsKeyDown(KEY_SPACE) && (IsKeyDown(KEY_RIGHT) || IsKeyDown(KEY_D))) { | ||
env->actions[0] = ACTION_RIGHTFIRE; // Accelerate and move right | ||
} else if (IsKeyDown(KEY_SPACE) && (IsKeyDown(KEY_LEFT) || IsKeyDown(KEY_A))) { | ||
env->actions[0] = ACTION_LEFTFIRE; // Accelerate and move left | ||
} else if (IsKeyDown(KEY_SPACE)) { | ||
env->actions[0] = ACTION_FIRE; // Accelerate | ||
} else if (IsKeyDown(KEY_DOWN) || IsKeyDown(KEY_S)) { | ||
env->actions[0] = ACTION_DOWN; // Decelerate | ||
} else if (IsKeyDown(KEY_LEFT) || IsKeyDown(KEY_A)) { | ||
env->actions[0] = ACTION_LEFT; // Move left | ||
} else if (IsKeyDown(KEY_RIGHT) || IsKeyDown(KEY_D)) { | ||
env->actions[0] = ACTION_RIGHT; // Move right | ||
} else { | ||
env->actions[0] = ACTION_NOOP; // No action | ||
} | ||
void get_input(Enduro *env) { | ||
if ((IsKeyDown(KEY_DOWN) && IsKeyDown(KEY_RIGHT)) || | ||
(IsKeyDown(KEY_S) && IsKeyDown(KEY_D))) { | ||
env->actions[0] = ACTION_DOWNRIGHT; // Decelerate and move right | ||
} else if ((IsKeyDown(KEY_DOWN) && IsKeyDown(KEY_LEFT)) || | ||
(IsKeyDown(KEY_S) && IsKeyDown(KEY_A))) { | ||
env->actions[0] = ACTION_DOWNLEFT; // Decelerate and move left | ||
} else if (IsKeyDown(KEY_SPACE) && | ||
(IsKeyDown(KEY_RIGHT) || IsKeyDown(KEY_D))) { | ||
env->actions[0] = ACTION_RIGHTFIRE; // Accelerate and move right | ||
} else if (IsKeyDown(KEY_SPACE) && | ||
(IsKeyDown(KEY_LEFT) || IsKeyDown(KEY_A))) { | ||
env->actions[0] = ACTION_LEFTFIRE; // Accelerate and move left | ||
} else if (IsKeyDown(KEY_SPACE)) { | ||
env->actions[0] = ACTION_FIRE; // Accelerate | ||
} else if (IsKeyDown(KEY_DOWN) || IsKeyDown(KEY_S)) { | ||
env->actions[0] = ACTION_DOWN; // Decelerate | ||
} else if (IsKeyDown(KEY_LEFT) || IsKeyDown(KEY_A)) { | ||
env->actions[0] = ACTION_LEFT; // Move left | ||
} else if (IsKeyDown(KEY_RIGHT) || IsKeyDown(KEY_D)) { | ||
env->actions[0] = ACTION_RIGHT; // Move right | ||
} else { | ||
env->actions[0] = ACTION_NOOP; // No action | ||
} | ||
} | ||
|
||
int demo() { | ||
Weights* weights = load_weights("resources/enduro/0105enduro_weights.bin", 142218); | ||
LinearLSTM* net = make_linearlstm(weights, 1, 68, 9); | ||
|
||
Enduro env = { | ||
.num_envs = 1, | ||
.max_enemies = MAX_ENEMIES, | ||
.obs_size = OBSERVATIONS_MAX_SIZE | ||
}; | ||
|
||
allocate(&env); | ||
GameState* client = make_client(&env); | ||
unsigned int seed = 0; | ||
init(&env, seed, 0); | ||
reset(&env); | ||
Weights *weights = | ||
load_weights("resources/enduro/enduro_weights.bin", 142218); | ||
LinearLSTM *net = make_linearlstm(weights, 1, 68, 9); | ||
|
||
while (!WindowShouldClose()) { | ||
if (IsKeyDown(KEY_LEFT_SHIFT)) { | ||
get_input(&env); | ||
} else { | ||
forward_linearlstm(net, env.observations, env.actions); | ||
} | ||
Enduro env = {.obs_size = OBSERVATIONS_MAX_SIZE}; | ||
allocate(&env); | ||
GameState *client = make_client(&env); | ||
unsigned int seed = 0; | ||
init(&env, seed, 0); | ||
reset(&env); | ||
|
||
c_step(&env); | ||
render(client, &env); | ||
while (!WindowShouldClose()) { | ||
if (IsKeyDown(KEY_LEFT_SHIFT)) { | ||
get_input(&env); | ||
} else { | ||
forward_linearlstm(net, env.observations, env.actions); | ||
} | ||
|
||
free_linearlstm(net); | ||
free(weights); | ||
close_client(client); | ||
free_allocated(&env); | ||
return 0; | ||
c_step(&env); | ||
render(client, &env); | ||
} | ||
|
||
free_linearlstm(net); | ||
free(weights); | ||
close_client(client); | ||
free_allocated(&env); | ||
return 0; | ||
} | ||
|
||
void perftest(float test_time) { | ||
Enduro env = { | ||
.num_envs = 1, | ||
.max_enemies = MAX_ENEMIES, | ||
.obs_size = OBSERVATIONS_MAX_SIZE | ||
}; | ||
|
||
allocate(&env); | ||
Enduro env = {.obs_size = OBSERVATIONS_MAX_SIZE}; | ||
allocate(&env); | ||
|
||
unsigned int seed = 12345; | ||
init(&env, seed, 0); | ||
reset(&env); | ||
unsigned int seed = 12345; | ||
init(&env, seed, 0); | ||
reset(&env); | ||
|
||
int start = time(NULL); | ||
int i = 0; | ||
while (time(NULL) - start < test_time) { | ||
env.actions[0] = rand()%9; | ||
c_step(&env); | ||
i++; | ||
} | ||
int start = time(NULL); | ||
int i = 0; | ||
while (time(NULL) - start < test_time) { | ||
env.actions[0] = rand() % 9; | ||
c_step(&env); | ||
i++; | ||
} | ||
|
||
int end = time(NULL); | ||
printf("SPS: %f\n", i / (float)(end - start)); | ||
free_allocated(&env); | ||
int end = time(NULL); | ||
printf("SPS: %f\n", i / (float)(end - start)); | ||
free_allocated(&env); | ||
} | ||
|
||
int main() { | ||
demo(); | ||
// perftest(20.0f); | ||
return 0; | ||
demo(); | ||
// perftest(20.0f); | ||
return 0; | ||
} |
Oops, something went wrong.