Skip to content

Commit

Permalink
Updating SubjectOnDisk to be able to read frames at a stride
Browse files Browse the repository at this point in the history
  • Loading branch information
keenon committed Jul 11, 2023
1 parent 4ec6a6c commit 705d999
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 6 deletions.
16 changes: 11 additions & 5 deletions dart/biomechanics/SubjectOnDisk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,18 +422,22 @@ std::string SubjectOnDisk::readRawOsimFileText()
///
/// On OOB access, prints an error and returns an empty vector.
std::vector<std::shared_ptr<Frame>> SubjectOnDisk::readFrames(
int trial, int startFrame, int numFramesToRead, s_t contactThreshold)
int trial,
int startFrame,
int numFramesToRead,
int stride,
s_t contactThreshold)
{
(void)trial;
(void)startFrame;
(void)stride;
(void)numFramesToRead;

std::vector<std::shared_ptr<Frame>> result;

// 1. Open the file
FILE* file = fopen(mPath.c_str(), "r");

// 2. Seek to the right place in the file to read this frame
int linearFrameStart = 0;
for (int i = 0; i < trial; i++)
{
Expand All @@ -454,11 +458,13 @@ std::vector<std::shared_ptr<Frame>> SubjectOnDisk::readFrames(
return result;
}

int offsetBytes = mDataSectionStart + (mFrameSize * linearFrameStart);
fseek(file, offsetBytes, SEEK_SET);

for (int i = 0; i < numFramesToRead; i++)
{
// 2. Seek to the right place in the file to read this frame
int offsetBytes = mDataSectionStart + (mFrameSize * linearFrameStart)
+ (i * stride * mFrameSize);
fseek(file, offsetBytes, SEEK_SET);

// 3. Allocate a buffer to hold the serialized data
std::vector<char> serializedFrame(mFrameSize);

Expand Down
1 change: 1 addition & 0 deletions dart/biomechanics/SubjectOnDisk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class SubjectOnDisk
int trial,
int startFrame,
int numFramesToRead = 1,
int stride = 1,
s_t contactThreshold = 1.0);

/// This writes a subject out to disk in a compressed and random-seekable
Expand Down
1 change: 1 addition & 0 deletions python/_nimblephysics/biomechanics/SubjectOnDisk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ Note that these are specified in the local body frame, acting on the body at its
::py::arg("trial"),
::py::arg("startFrame"),
::py::arg("numFramesToRead") = 1,
::py::arg("stride") = 1,
::py::arg("contactThreshold") = 1.0,
"This will read from disk and allocate a number of "
":code:`Frame` "
Expand Down
2 changes: 1 addition & 1 deletion stubs/_nimblephysics-stubs/biomechanics/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2048,7 +2048,7 @@ class SubjectOnDisk():
"""
This returns the timestep size for the trial requested, in seconds per frame
"""
def readFrames(self, trial: int, startFrame: int, numFramesToRead: int = 1, contactThreshold: float = 1.0) -> typing.List[Frame]:
def readFrames(self, trial: int, startFrame: int, numFramesToRead: int = 1, stride: int = 1, contactThreshold: float = 1.0) -> typing.List[Frame]:
"""
This will read from disk and allocate a number of :code:`Frame` objects. These Frame objects are assumed to be short-lived, to save working memory. For example, you might :code:`readFrames()` to construct a training batch, then immediately allow the frames to go out of scope and be released after the batch backpropagates gradient and loss. On OOB access, prints an error and returns an empty vector.
"""
Expand Down
20 changes: 20 additions & 0 deletions unittests/unit/test_SubjectOnDisk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -686,4 +686,24 @@ TEST(SubjectOnDisk, HAMNER_RUNNING)
EXPECT_EQ(subject.getNumTrials(), 4);
EXPECT_GT(subject.readFrames(0, 7, 10).size(), 0);
}
#endif

#ifdef ALL_TESTS
TEST(SubjectOnDisk, HAMNER_RUNNING_READ_WITH_DATA_STRIDE)
{
auto retriever = std::make_shared<utils::CompositeResourceRetriever>();
retriever->addSchemaRetriever(
"file", std::make_shared<common::LocalResourceRetriever>());
retriever->addSchemaRetriever("dart", DartResourceRetriever::create());
std::string path = retriever->getFilePath(
"dart://sample/subjectOnDisk/HamnerRunning2013Subject01.bin");

SubjectOnDisk subject(path);
EXPECT_EQ(subject.getNumTrials(), 4);
auto frames = subject.readFrames(0, 7, 10, 5);
EXPECT_GT(frames.size(), 2);
EXPECT_EQ(frames[0]->t, 7);
EXPECT_EQ(frames[1]->t, 7 + 5);
EXPECT_EQ(frames[2]->t, 7 + 10);
}
#endif

0 comments on commit 705d999

Please sign in to comment.