diff --git a/README.md b/README.md
index 5de778f..3ef234a 100644
--- a/README.md
+++ b/README.md
@@ -73,4 +73,6 @@ You may not put data on the GPU from any other thread than the main thread, henc
 
 If `T` is omitted while wrapping a channel, it is assumed that `f(eltype(dataset)) == typeof(f(::eltype(dataset)))` or in words, `f` must have a method returning the type resulting from applying `f` to an element of the wrapped channel.
 
-By having `f=cu` or `f=gpu` which puts data on a GPU, you now have an efficient way of training models on the GPU, while reading data in a separate thread. Primitive benchmarking showed some 0-20% performance improvement using this strategy over putting data on the GPU as it is taken out of the dataset. If `cu/gpu` become thread-safe, this improvement may become larger.
+By having `f=cu` or `f=gpu` which puts data on a GPU, you now have an efficient way of training models on the GPU, while reading data in a separate thread. Primitive benchmarking showed some 0-20% performance improvement using this strategy over putting data on the GPU as it is taken out of the dataset. If `cu/gpu` become thread-safe, this improvement may become larger.
+
+*Note:* If your entire dataset fits onto the GPU and you do not run out of memory while performing backpropagation, the fastest method is *by far* to keep all data on the GPU during the entire training. You can try this with `gpu.(collect(dataset))`, or `collect(dataset)` if the channel already puts data on the GPU. The function `fullsizeof(dataset::LengthChannel)` will tell you the size in bytes required to `collect` the dataset.
diff --git a/src/LengthChannels.jl b/src/LengthChannels.jl
index 0619018..8486d7b 100644
--- a/src/LengthChannels.jl
+++ b/src/LengthChannels.jl
@@ -1,5 +1,5 @@
 module LengthChannels
-export LengthChannel
+export LengthChannel, fullsizeof
 """
 This package defines a type `LengthChannel{T} <: AbstractChannel{T}` which simply adds information about the length of the channel when it is iterated. The constructor behaves the same as the constructor for `Channel`, but takes an additional integer specifying the length. This length is not to be confused with the buffer size of the channel, referred to as `buf` in the example below.
 When a `LengthChannel` is iterated, it continues until it has iterated the specified number of elements, after that the channel is closed, even if there are more elements put in the channel.
@@ -126,6 +126,16 @@ end
 
 Base.length(lc::LengthChannel) = lc.l
 
+"""
+    fullsizeof(lc::LengthChannel)
+
+Return the sum of `sizeof` over all elements, assuming they all have the same `sizeof`.
+"""
+function fullsizeof(lc::LengthChannel)
+    isready(lc) || error("Channel is not ready")
+    sizeof(fetch(lc))*length(lc)
+end
+
 for f in (bind, close, fetch, isopen, isready, lock, popfirst!, push!, put!, take!, trylock, unlock, wait, eltype)
     f = nameof(f)
     @eval Base.$f(lc::LengthChannel, args...; kwargs...) = $(f)(lc.ch, args...; kwargs...)
diff --git a/test/runtests.jl b/test/runtests.jl
index 73eb259..aaba9ef 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -13,6 +13,8 @@ using Test
         end
     end
 
+    @test fullsizeof(lc) == l*sizeof(1)
+
     @test eltype(lc) <: Int
     @test length(lc) == l
 
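
The `fullsizeof` function added in this diff peeks at the first buffered element with `fetch` and multiplies its `sizeof` by the channel length. The sketch below reproduces that logic against a plain buffered `Channel` so that it runs without the package installed. The helper name `estimate_fullsize` and its explicit length argument `n` are hypothetical stand-ins for illustration and are not part of LengthChannels.

```julia
# Hypothetical stand-in that mirrors the logic of `fullsizeof` from the diff,
# written for a plain buffered `Channel` (not the package's `LengthChannel`).
function estimate_fullsize(ch::Channel, n::Integer)
    # Same guard as in the diff: an element must already be buffered.
    isready(ch) || error("Channel is not ready")
    # `fetch` returns the first buffered element without removing it.
    return sizeof(fetch(ch)) * n
end

n = 5
ch = Channel{Vector{Float32}}(n)      # buffered channel with room for n elements
for _ in 1:n
    put!(ch, zeros(Float32, 10))      # each element holds 10 * 4 = 40 bytes of data
end

nbytes = estimate_fullsize(ch, n)
println("Estimated bytes needed to collect: ", nbytes)
```

Because `sizeof` is shallow in Julia, an estimate built this way is only meaningful when each element is a single array or a plain bits value. For an element such as a tuple of arrays, `sizeof` reports the size of the references, not the array data, and `Base.summarysize` may be a better fit in that case.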