# Patch for ext/CUDAExt.jl touching four kernel functions: gather_kernel!, scatter_kernel!,
# scale_kernel! and cache!_kernel!.
# In each kernel the patch (1) writes the literal in `blockIdx().x - 1` as `Int32(1)` and
# (2) replaces `for i = index:stride:length(...)` with `i = index` followed by a
# `while i <= length(...)` loop that ends with `i += stride`. The statement inside each
# loop is carried as unchanged context.
# The last hunk additionally drops three empty lines ahead of the final `end`.
# NOTE(review): the hunks appear here joined onto two physical lines; the line breaks and
# blank context lines implied by the @@ counts must be restored before this can be applied.
# NOTE(review): presumably index/stride are 32-bit values (the Int32(1) literal suggests
# so) - confirm `i += stride` cannot overflow for the largest vector length expected.
diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl index ce2886e..c683ff0 100644 --- a/ext/CUDAExt.jl +++ b/ext/CUDAExt.jl @@ -52,10 +52,12 @@ function FixedEffects.gather!(fecoef::CuVector, refs::CuVector, α::Number, y::C end function gather_kernel!(fecoef, refs, α, y, cache) - index = (blockIdx().x - 1) * blockDim().x + threadIdx().x + index = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x stride = blockDim().x * gridDim().x - @inbounds for i = index:stride:length(y) + i = index + @inbounds while i <= length(y) CUDA.@atomic fecoef[refs[i]] += α * y[i] * cache[i] + i += stride end end @@ -65,10 +67,12 @@ function FixedEffects.scatter!(y::CuVector, α::Number, fecoef::CuVector, refs:: end function scatter_kernel!(y, α, fecoef, refs, cache) - index = (blockIdx().x - 1) * blockDim().x + threadIdx().x + index = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x stride = blockDim().x * gridDim().x - @inbounds for i = index:stride:length(y) + i = index + @inbounds while i <= length(y) y[i] += α * fecoef[refs[i]] * cache[i] + i += stride end end @@ -128,10 +132,12 @@ function scale!(scale::CuVector, refs::CuVector, interaction::CuVector, weights: end function scale_kernel!(scale, refs, interaction, weights) - index = (blockIdx().x - 1) * blockDim().x + threadIdx().x + index = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x stride = blockDim().x * gridDim().x - @inbounds for i = index:stride:length(interaction) + i = index + @inbounds while i <= length(interaction) CUDA.@atomic scale[refs[i]] += abs2(interaction[i]) * weights[i] + i += stride end end @@ -141,13 +147,12 @@ function cache!(cache::CuVector, refs::CuVector, interaction::CuVector, weights: end function cache!_kernel!(cache, refs, interaction, weights, scale) - index = (blockIdx().x - 1) * blockDim().x + threadIdx().x + index = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x stride = blockDim().x * gridDim().x - @inbounds for i = index:stride:length(cache) + i = index + @inbounds 
while i <= length(cache) cache[i] = interaction[i] * sqrt(weights[i]) * scale[refs[i]] + i += stride end end - - - end