Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FixedEffects"
uuid = "c8885935-8500-56a7-9867-7708b20db0eb"
version = "3.1.0"
version = "3.2.0"

[deps]
GroupedArrays = "6407cd72-fade-4a84-8a1e-56e431fc1533"
Expand Down
12 changes: 7 additions & 5 deletions ext/CUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module CUDAExt
using FixedEffects, CUDA
using FixedEffects: FixedEffectCoefficients, AbstractWeights, UnitWeights, LinearAlgebra, Adjoint, mul!, rmul!, lsmr!, AbstractFixedEffectLinearMap, copy_internal!
using FixedEffects: FixedEffectCoefficients, AbstractWeights, UnitWeights, LinearAlgebra, Adjoint, mul!, rmul!, lsmr!, AbstractFixedEffectLinearMap, copy_internal!, AtomicGather
CUDA.allowscalar(false)

##############################################################################
Expand Down Expand Up @@ -36,19 +36,21 @@ mutable struct FixedEffectLinearMapCUDA{T} <: AbstractFixedEffectLinearMap{T}
fes::Vector{<:FixedEffect}
scales::Vector{<:AbstractVector}
caches::Vector{<:AbstractVector}
gathers::Vector{AtomicGather}
end

function FixedEffectLinearMapCUDA{T}(fes::Vector{<:FixedEffect}) where {T}
fes = [_cu(T, fe) for fe in fes]
scales = [CUDA.zeros(T, fe.n) for fe in fes]
caches = [CUDA.zeros(T, length(fes[1].interaction)) for fe in fes]
return FixedEffectLinearMapCUDA{T}(fes, scales, caches)
gathers = [AtomicGather() for fe in fes]
return FixedEffectLinearMapCUDA{T}(fes, scales, caches, gathers)
end

function FixedEffects.gather!(fecoef::CuVector, refs::CuVector, α::Number, y::CuVector, cache::CuVector)
function FixedEffects.gather!(fecoef::CuVector, refs::CuVector, α::Number, y::CuVector, cache::CuVector, ::AtomicGather)
nthreads = 256
nblocks = cld(length(y), nthreads)
@cuda threads=nthreads blocks=nblocks gather_kernel!(fecoef, refs, α, y, cache)
nblocks = cld(length(y), nthreads)
@cuda threads=nthreads blocks=nblocks gather_kernel!(fecoef, refs, α, y, cache)
end

function gather_kernel!(fecoef, refs, α, y, cache)
Expand Down
128 changes: 70 additions & 58 deletions ext/MetalExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module MetalExt
using FixedEffects, Metal
using FixedEffects: FixedEffectCoefficients, AbstractWeights, UnitWeights, LinearAlgebra, Adjoint, mul!, rmul!, lsmr!, AbstractFixedEffectLinearMap, copy_internal!
using FixedEffects: FixedEffectCoefficients, AbstractWeights, UnitWeights, LinearAlgebra, Adjoint, mul!, rmul!, lsmr!, AbstractFixedEffectLinearMap, copy_internal!, AtomicGather, BucketGather
Metal.allowscalar(false)

##############################################################################
Expand Down Expand Up @@ -33,71 +33,80 @@ _mtl(T::Type, w::AbstractVector) = MtlVector{T}(convert(Vector{T}, w))

mutable struct FixedEffectLinearMapMetal{T} <: AbstractFixedEffectLinearMap{T}
fes::Vector{<:FixedEffect}
scales::Vector{<:AbstractVector}
caches::Vector
scales::Vector{MtlVector{T}}
caches::Vector{MtlVector{T}}
gathers::Vector{Union{AtomicGather, BucketGather}}
end

function bucketize_refs(refs::Vector, n::Int)
function _metal_threadgroup_width()
width = Int(device().maxThreadsPerThreadgroup.width)
return prevpow(2, width)
end

function bucketize_refs(refs::AbstractVector{<:Integer}, n::Int)
# count the number of obs per group
counts = zeros(Int, n)
@inbounds for r in refs
counts[r] += 1
end
counts = zeros(Int, n)
@inbounds for r in refs
counts[r] += 1
end
# offsets is vcat(1, cumsum(counts))
offsets_mtl = Metal.@sync Metal.zeros(Int, n + 1; storage = Metal.SharedStorage)
offsets = unsafe_wrap(Array{Int}, offsets_mtl, size(offsets_mtl))
offsets[1] = 1
@inbounds for k in 1:n
offsets[k+1] = offsets[k] + counts[k]
end
offsets = Vector{Int}(undef, n + 1)
offsets[1] = 1
@inbounds for k in 1:n
offsets[k+1] = offsets[k] + counts[k]
end

perm_mtl = Metal.@sync Metal.zeros(Int, length(refs); storage = Metal.SharedStorage)
perm = unsafe_wrap(Array{Int}, perm_mtl, size(perm_mtl))
next = offsets[1:n]
@inbounds for i in eachindex(refs)
r = refs[i]
p = next[r]
perm[p] = i
next[r] = p + 1
end
return perm_mtl, offsets_mtl
perm = Vector{Int}(undef, length(refs))
next = offsets[1:n]
@inbounds for i in eachindex(refs)
r = refs[i]
p = next[r]
perm[p] = i
next[r] = p + 1
end
return MtlVector{Int}(perm), MtlVector{Int}(offsets)
end

function FixedEffectLinearMapMetal{T}(fes::Vector{<:FixedEffect}) where {T}
fes2 = [_mtl(T, fe) for fe in fes]
scales = [Metal.zeros(T, fe.n) for fe in fes]
caches = [Any[Metal.zeros(T, length(fe.refs)), Metal.zeros(Int, 1), Metal.zeros(Int, 1)] for fe in fes]
caches = [Metal.zeros(T, length(fe.refs)) for fe in fes]
G = Union{AtomicGather, BucketGather}
gathers = Vector{G}(undef, length(fes))
Threads.@threads for i in 1:length(fes)
refs = fes[i].refs
n = fes[i].n
if n < min(100_000, div(length(refs), 16))
out = bucketize_refs(refs, n)
caches[i][2] = out[1]
caches[i][3] = out[2]
# bucketize (one threadgroup per group) for low cardinality; else atomic adds
if n < min(100_000, div(length(refs), 16))
perm, offsets = bucketize_refs(refs, n)
gathers[i] = BucketGather(perm, offsets)
else
gathers[i] = AtomicGather()
end
end
return FixedEffectLinearMapMetal{T}(fes2, scales, caches)
return FixedEffectLinearMapMetal{T}(fes2, scales, caches, gathers)
end

function FixedEffects.gather!(fecoef::MtlVector, refs::MtlVector, α::Number, y::MtlVector, cache::Vector)
function FixedEffects.gather!(fecoef::MtlVector, refs::MtlVector, α::Number, y::MtlVector, cache::MtlVector, g::BucketGather)
n = length(fecoef)
nthreads = Int(device().maxThreadsPerThreadgroup.width)
if n < min(100_000, div(length(refs), 16))
Metal.@sync @metal threads=nthreads groups=n gather_kernel_bin!(fecoef, refs, α, y, cache[1], cache[2], cache[3], Val(nthreads))
else
nblocks = cld(length(y), nthreads)
Metal.@sync @metal threads=nthreads groups=nblocks gather_kernel!(fecoef, refs, α, y, cache[1])
end
nthreads = _metal_threadgroup_width()
Metal.@sync @metal threads=nthreads groups=n gather_kernel_bin!(fecoef, refs, α, y, cache, g.perm, g.offsets, Val(nthreads))
end

function FixedEffects.gather!(fecoef::MtlVector, refs::MtlVector, α::Number, y::MtlVector, cache::MtlVector, ::AtomicGather)
nthreads = _metal_threadgroup_width()
nblocks = cld(length(y), nthreads)
Metal.@sync @metal threads=nthreads groups=nblocks gather_kernel!(fecoef, refs, α, y, cache)
end

function gather_kernel_bin!(fecoef, refs, α, y, cache, perm, offsets, ::Val{NT}) where {NT}
k = Int(threadgroup_position_in_grid().x) # 1..K (Julia-style indexing) :contentReference[oaicite:2]{index=2}
tid = Int(thread_position_in_threadgroup().x) # 1..nthreads :contentReference[oaicite:3]{index=3}
nt = Int(threads_per_threadgroup().x) # nthreads :contentReference[oaicite:4]{index=4}
k = Int(threadgroup_position_in_grid().x)
tid = Int(thread_position_in_threadgroup().x)
nt = Int(threads_per_threadgroup().x)

# threadgroup scratch
T = eltype(fecoef)
shared = Metal.MtlThreadGroupArray(T, NT) # threadgroup-local array :contentReference[oaicite:5]{index=5}
shared = Metal.MtlThreadGroupArray(T, NT)

start = @inbounds offsets[k]
stop = @inbounds offsets[k+1] - 1
Expand All @@ -113,7 +122,7 @@ function gather_kernel_bin!(fecoef, refs, α, y, cache, perm, offsets, ::Val{NT}
end

@inbounds shared[tid] = acc
Metal.threadgroup_barrier(Metal.MemoryFlagThreadGroup) # sync + tg fence :contentReference[oaicite:6]{index=6}
Metal.threadgroup_barrier(Metal.MemoryFlagThreadGroup)

# tree reduction in shared memory
offset = nt ÷ 2
Expand Down Expand Up @@ -141,10 +150,10 @@ function gather_kernel!(fecoef, refs, α, y, cache)
return nothing
end

function FixedEffects.scatter!(y::MtlVector, α::Number, fecoef::MtlVector, refs::MtlVector, cache::Vector)
nthreads = Int(device().maxThreadsPerThreadgroup.width)
function FixedEffects.scatter!(y::MtlVector, α::Number, fecoef::MtlVector, refs::MtlVector, cache::MtlVector)
nthreads = _metal_threadgroup_width()
nblocks = cld(length(y), nthreads)
Metal.@sync @metal threads=nthreads groups=nblocks scatter_kernel!(y, α, fecoef, refs, cache[1])
Metal.@sync @metal threads=nthreads groups=nblocks scatter_kernel!(y, α, fecoef, refs, cache)
end

function scatter_kernel!(y, α, fecoef, refs, cache)
Expand Down Expand Up @@ -172,19 +181,22 @@ mutable struct FixedEffectSolverMetal{T} <: FixedEffects.AbstractFixedEffectSolv
v::FixedEffectCoefficients{<: AbstractVector{T}}
h::FixedEffectCoefficients{<: AbstractVector{T}}
hbar::FixedEffectCoefficients{<: AbstractVector{T}}
tmp::Vector{T}
fes::Vector{<:FixedEffect}
end


function FixedEffects.AbstractFixedEffectSolver{T}(fes::Vector{<:FixedEffect}, weights::AbstractWeights, ::Type{Val{:Metal}}) where {T}
T === Float32 || throw(ArgumentError("The Metal backend supports Float32 solves only; pass double_precision=false or use method=:cpu for Float64."))
m = FixedEffectLinearMapMetal{T}(fes)
b = Metal.zeros(T, length(weights); storage = Metal.SharedStorage)
r = Metal.zeros(T, length(weights); storage = Metal.SharedStorage)
b = Metal.zeros(T, length(weights))
r = Metal.zeros(T, length(weights))
x = FixedEffectCoefficients([Metal.zeros(T, fe.n) for fe in fes])
v = FixedEffectCoefficients([Metal.zeros(T, fe.n) for fe in fes])
h = FixedEffectCoefficients([Metal.zeros(T, fe.n) for fe in fes])
hbar = FixedEffectCoefficients([Metal.zeros(T, fe.n) for fe in fes])
feM = FixedEffectSolverMetal{T}(m, Metal.zeros(T, length(weights)), b, r, x, v, h, hbar, fes)
tmp = zeros(T, length(weights))
feM = FixedEffectSolverMetal{T}(m, Metal.zeros(T, length(weights)), b, r, x, v, h, hbar, tmp, fes)
FixedEffects.update_weights!(feM, weights)
end

Expand All @@ -201,7 +213,7 @@ function FixedEffects.update_weights!(feM::FixedEffectSolverMetal{T}, weights::A
end

function scale!(scale::MtlVector, refs::MtlVector, interaction::MtlVector, weights::MtlVector)
nthreads = Int(device().maxThreadsPerThreadgroup.width)
nthreads = _metal_threadgroup_width()
nblocks = cld(length(refs), nthreads)
fill!(scale, 0)
Metal.@sync @metal threads=nthreads groups=nblocks scale_kernel!(scale, refs, interaction, weights)
Expand All @@ -224,10 +236,10 @@ function inv_kernel!(scale, T)
return nothing
end

function cache!(cache, refs::MtlVector, interaction::MtlVector, weights::MtlVector, scale::MtlVector)
nthreads = Int(device().maxThreadsPerThreadgroup.width)
nblocks = cld(length(cache[1]), nthreads)
Metal.@sync @metal threads=nthreads groups=nblocks cache!_kernel!(cache[1], refs, interaction, weights, scale)
function cache!(cache::MtlVector, refs::MtlVector, interaction::MtlVector, weights::MtlVector, scale::MtlVector)
nthreads = _metal_threadgroup_width()
nblocks = cld(length(cache), nthreads)
Metal.@sync @metal threads=nthreads groups=nblocks cache!_kernel!(cache, refs, interaction, weights, scale)
end

function cache!_kernel!(cache, refs, interaction, weights, scale)
Expand All @@ -240,14 +252,14 @@ end

function FixedEffects.copy_internal!(feM::FixedEffectSolverMetal{T}, field::Symbol, r::AbstractVector) where {T}
synchronize()
feM_r = unsafe_wrap(Array{T}, getfield(feM, field), size(getfield(feM, field)))
copyto!(feM_r, r)
copyto!(feM.tmp, r)
copyto!(getfield(feM, field), feM.tmp)
end

function FixedEffects.copy_internal!(r::AbstractVector, feM::FixedEffectSolverMetal{T}, field::Symbol) where {T}
synchronize()
feM_r = unsafe_wrap(Array{T}, getfield(feM, field), size(getfield(feM, field)))
copyto!(r, feM_r)
copyto!(feM.tmp, getfield(feM, field))
copyto!(r, feM.tmp)
end


Expand Down
19 changes: 17 additions & 2 deletions src/AbstractFixedEffectLinearMap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,21 @@

abstract type AbstractFixedEffectLinearMap{T} end

# Per-fixed-effect plan for the adjoint gather (A'u). Chosen once at construction and
# dispatched on by each backend's `gather!`. The forward map (A = scatter) needs only the
# preconditioner `caches[i]` and is shared; the gather plan lives in `gathers[i]`.
# CPU uses Serial/Threaded; the GPU backends use Atomic/Bucket.
struct SerialGather end
struct ThreadedGather{V<:AbstractVector}
buffers::Vector{V} # one length-fe.n accumulator per thread
ranges::Vector{UnitRange{Int}} # contiguous row chunks
end
struct AtomicGather end
struct BucketGather{V<:AbstractVector}
perm::V # observation indices sorted by group
offsets::V # CSR offsets into perm (length ngroups + 1)
end

Base.adjoint(fem::AbstractFixedEffectLinearMap) = Adjoint(fem)

function Base.size(fem::AbstractFixedEffectLinearMap, dim::Integer)
Expand All @@ -28,8 +43,8 @@ function LinearAlgebra.mul!(fecoefs::FixedEffectCoefficients,
y::AbstractVector, α::Number, β::Number) where {T}
fem = adjoint(Cfem)
rmul!(fecoefs, β)
for (fecoef, fe, cache) in zip(fecoefs.x, fem.fes, fem.caches)
gather!(fecoef, fe.refs, α, y, cache)
for (fecoef, fe, cache, gather) in zip(fecoefs.x, fem.fes, fem.caches, fem.gathers)
gather!(fecoef, fe.refs, α, y, cache, gather)
end
return fecoefs
end
Expand Down
30 changes: 18 additions & 12 deletions src/AbstractFixedEffectSolver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ Returns ``y_i - X_i'\\beta`` where ``\\beta = argmin_{b} \\sum_i y_i - X_i'b``,
* `fes`: A `Vector{<:FixedEffect}`
* `w`: A vector of weights, i.e. `AbstractWeights`
* `method` : A symbol between :cpu (default), :CUDA, or :Metal
* `double_precision::Bool`: Should the demeaning operation use Float64 rather than Float32? Default to method == : cpu.
* `double_precision::Bool`: Should the demeaning operation use Float64 rather than Float32? Default to method == :cpu. GPU backends use Float32 by default; Float32 solves use a looser default tolerance and can be less accurate than CPU Float64 solves.
* `tol` : Tolerance. Default to 1e-8 if `double_precision = true`, 1e-6 otherwise.
* `maxiter` : Maximum number of iterations
* `maxiter` : Maximum number of LSMR iterations

### Returns
* `res` : Residual of the least square problem
Expand Down Expand Up @@ -47,18 +47,22 @@ end


function solve_residuals!(r::AbstractVector{<:Real}, feM::AbstractFixedEffectSolver{T}; tol::Real = sqrt(eps(T)), maxiter::Integer = 100_000) where {T}
maxiter >= 0 || throw(ArgumentError("maxiter must be non-negative"))
# One cannot copy view of Vector (r) on GPU, so first collect the vector
copy_internal!(feM, :r, r)
if !(feM.weights isa UnitWeights)
feM.r .*= sqrt.(feM.weights)
end
copyto!(feM.b, feM.r)
mul!(feM.x, feM.m', feM.b, 1, 0)
iter, converged = 1, true
if length(feM.x.x) > 1
x, ch = lsmr!(feM.x, feM.m, feM.b, feM.v, feM.h, feM.hbar; atol = tol, btol = tol, maxiter = maxiter - 1)
iter, converged = ch.mvps + 1, ch.isconverged
end
fill!(feM.x, zero(T))
iter, converged = 0, true
if length(feM.x.x) == 1
mul!(feM.x, feM.m', feM.b, 1, 0)
else
_, ch = lsmr!(feM.x, feM.m, feM.b, feM.v, feM.h, feM.hbar; atol = tol, btol = tol, maxiter = maxiter)
iter, converged = ch.mvps, ch.isconverged
end
converged || @warn "solve_residuals! did not converge within maxiter LSMR iterations; returned values may be inaccurate." iterations=iter maxiter tol
mul!(feM.r, feM.m, feM.x, -1, 1)
if !(feM.weights isa UnitWeights)
feM.r ./= sqrt.(feM.weights)
Expand Down Expand Up @@ -111,9 +115,9 @@ Returns ``\\beta = argmin_{b} \\sum_i w_i(y_i - X_i'b)`` where `X` denotes the m
* `fes`: A `Vector{<:FixedEffect}`
* `w`: A vector of weights, i.e. `AbstractWeights`
* `method` : A symbol between :cpu (default), :CUDA, or :Metal
* `double_precision::Bool`: Should the demeaning operation use Float64 rather than Float32? Default to method == :cpu.
* `double_precision::Bool`: Should the demeaning operation use Float64 rather than Float32? Default to method == :cpu. GPU backends use Float32 by default; Float32 solves use a looser default tolerance and can be less accurate than CPU Float64 solves.
* `tol` : Tolerance. Default to 1e-8 if `double_precision = true`, 1e-6 otherwise.
* `maxiter` : Maximum number of iterations
* `maxiter` : Maximum number of LSMR iterations


### Returns
Expand Down Expand Up @@ -145,16 +149,18 @@ function solve_coefficients!(y::AbstractVector{<: Number}, fes::AbstractVector{<
end

function solve_coefficients!(r::AbstractVector, feM::AbstractFixedEffectSolver{T}; tol::Real = sqrt(eps(T)), maxiter::Integer = 100_000) where {T}
maxiter >= 0 || throw(ArgumentError("maxiter must be non-negative"))
# One cannot copy view of Vector (r) on GPU, so first collect the vector
copy_internal!(feM, :b, r)
if !(feM.weights isa UnitWeights)
feM.b .*= sqrt.(feM.weights)
end
fill!(feM.x, zero(T))
x, ch = lsmr!(feM.x, feM.m, feM.b, feM.v, feM.h, feM.hbar; atol = tol, btol = tol, maxiter = maxiter)
_, ch = lsmr!(feM.x, feM.m, feM.b, feM.v, feM.h, feM.hbar; atol = tol, btol = tol, maxiter = maxiter)
ch.isconverged || @warn "solve_coefficients! did not converge within maxiter LSMR iterations; returned values may be inaccurate." iterations=ch.mvps maxiter tol
for (x, scale) in zip(feM.x.x, feM.m.scales)
x .*= scale
end
x = Vector{eltype(r)}[collect(x) for x in feM.x.x]
full(normalize!(x, feM.m.fes), feM.m.fes), div(ch.mvps, 2), ch.isconverged
full(normalize!(x, feM.m.fes), feM.m.fes), ch.mvps, ch.isconverged
end
1 change: 1 addition & 0 deletions src/FixedEffects.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ module FixedEffects
##############################################################################

using Base: @propagate_inbounds
using Base.Threads: @threads, nthreads
using LinearAlgebra: LinearAlgebra, Adjoint, mul!, rmul!, norm, axpy!
using PrecompileTools: @setup_workload, @compile_workload
using StatsBase: AbstractWeights, UnitWeights, Weights, uweights
Expand Down
Loading
Loading