Skip to content

Commit

Permalink
add fullsizeof
Browse files Browse the repository at this point in the history
  • Loading branch information
baggepinnen committed Nov 29, 2019
1 parent 720737e commit fbf8e17
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 2 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,6 @@ You may not put data on the GPU from any other thread than the main thread, henc

If `T` is omitted while wrapping a channel, it is assumed that `f(eltype(dataset)) == typeof(f(::eltype(dataset)))` or in words, `f` must have a method returning the type resulting from applying `f` to an element of the wrapped channel.

By having `f=cu` or `f=gpu` which puts data on a GPU, you now have an efficient way of training models on the GPU, while reading data in a separate thread. Primitive benchmarking showed some 0-20% performance improvement using this strategy over putting data on the GPU as it is taken out of the dataset. If `cu/gpu` become thread-safe, this improvement may become larger.
By having `f=cu` or `f=gpu` which puts data on a GPU, you now have an efficient way of training models on the GPU, while reading data in a separate thread. Primitive benchmarking showed some 0-20% performance improvement using this strategy over putting data on the GPU as it is taken out of the dataset. If `cu/gpu` become thread-safe, this improvement may become larger.

*Note:* If your entire dataset fit onto the GPU and you do not run out of memory while performing backpropagation, the fastest method is *by far* to keep all data on the GPU during the entire training. You can try by simply `gpu.(collect(dataset))` or `collect(dataset)` if the channel already puts data on the GPU. The function `fullsizeof(dataset::LengthChannel)` will tell you the size in bytes required to `collect` the dataset.
12 changes: 11 additions & 1 deletion src/LengthChannels.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
module LengthChannels
export LengthChannel
export LengthChannel, fullsizeof

"""
This package defines a type `LengthChannel{T} <: AbstractChannel{T}` which simply adds information about the length of the channel when it is iterated. The constructor behaves the same as the constructor for `Channel`, but takes an additional integer specifying the length. This length is not to be confused with the buffer size of the channel, referred to as `buf` in the example below. When a `LengthChannel` is iterated, it continues until it has iterated the specified number of elements, after that the channel is closed, even if there are more elements put in the channel.
Expand Down Expand Up @@ -126,6 +126,16 @@ end

Base.length(lc::LengthChannel) = lc.l

"""
fullsizeof(lc::LengthChannel)
Return the sum of `sizeof` over all elements, assuming they all have the same `sizeof`.
"""
function fullsizeof(lc::LengthChannel)
isready(lc) || error("Channel is not ready")
sizeof(fetch(lc))*length(lc)
end

for f in (bind, close, fetch, isopen, isready, lock, popfirst!, push!, put!, take!, trylock, unlock, wait, eltype)
f = nameof(f)
@eval Base.$f(lc::LengthChannel, args...; kwargs...) = $(f)(lc.ch, args...; kwargs...)
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ using Test
end
end

@test fullsizeof(lc) == l*sizeof(1)

@test eltype(lc) <: Int

@test length(lc) == l
Expand Down

0 comments on commit fbf8e17

Please sign in to comment.