Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "oneAPI"
uuid = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
version = "2.6.0"
authors = ["Tim Besard <tim.besard@gmail.com>", "Alexis Montoison", "Michel Schanen <michel.schanen@gmail.com>"]
version = "2.6.0"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
Expand All @@ -29,22 +30,16 @@ oneAPI_Level_Zero_Headers_jll = "f4bc562b-d309-54f8-9efb-476e56f0410d"
oneAPI_Level_Zero_Loader_jll = "13eca655-d68d-5b81-8367-6d99d727ab01"
oneAPI_Support_jll = "b049733a-a71d-5ed3-8eba-7d323ac00b36"

[weakdeps]
AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"

[extensions]
oneAPIAcceleratedKernelsExt = "AcceleratedKernels"

[compat]
AbstractFFTs = "1.5.0"
AcceleratedKernels = "0.4.3"
AcceleratedKernels = "0.3.1, 0.4"
Adapt = "4"
CEnum = "0.4, 0.5"
ExprTools = "0.1"
GPUArrays = "11.2.1"
GPUCompiler = "1.6"
GPUToolbox = "0.1, 0.2, 0.3, 1"
KernelAbstractions = "0.9.1"
KernelAbstractions = "0.9.39"
LLVM = "6, 7, 8, 9"
NEO_jll = "=25.44.36015"
Preferences = "1"
Expand Down
2 changes: 2 additions & 0 deletions lib/level-zero/device.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ Base.length(iter::ZeDevices) = length(iter.handles)

Base.IteratorSize(::ZeDevices) = Base.HasLength()

# 1-based index range of the iterator, enabling generic code that uses
# `keys`/`eachindex`-style access over the available devices.
Base.keys(iter::ZeDevices) = 1:length(iter)

function Base.show(io::IO, ::MIME"text/plain", iter::ZeDevices)
print(io, "ZeDevice iterator for $(length(iter)) devices")
if !isempty(iter)
Expand Down
4 changes: 0 additions & 4 deletions ext/oneAPIAcceleratedKernelsExt.jl → src/accumulate.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
module oneAPIAcceleratedKernelsExt

import oneAPI
import oneAPI: oneArray, oneAPIBackend
import AcceleratedKernels as AK
Expand All @@ -13,5 +11,3 @@ Base.accumulate(op, A::oneArray; init = zero(eltype(A)), kwargs...) =

# Route `Base.cumsum`/`cumprod` on device arrays to AcceleratedKernels'
# GPU implementations running on the oneAPI backend.
Base.cumsum(src::oneArray; kwargs...) = AK.cumsum(src, oneAPIBackend(); kwargs...)
Base.cumprod(src::oneArray; kwargs...) = AK.cumprod(src, oneAPIBackend(); kwargs...)

end # module
8 changes: 7 additions & 1 deletion src/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,13 @@ See also: [`device`](@ref), [`devices`](@ref)
function device!(drv::ZeDevice)
task_local_storage(:ZeDevice, drv)
end
device!(i::Int) = device!(devices(driver())[i])
"""
    device!(i::Int)

Make the `i`-th device (1-based) of the current driver the active device for
the current task. Throws an `ArgumentError` when `i` is out of range.
"""
function device!(i::Int)
    available = devices(driver())
    if !(1 <= i <= length(available))
        throw(ArgumentError("Invalid device index $i (must be between 1 and $(length(available)))"))
    end
    return device!(available[i])
end

const global_contexts = Dict{ZeDriver,ZeContext}()

Expand Down
7 changes: 7 additions & 0 deletions src/device/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,13 @@ end
end
end

# Cached (read-only) load fallback for device code. CUDA exposes `__ldg` for
# loads through the read-only data cache; SPIR-V / Level Zero has no explicit
# cache-control intrinsic, so this forwards to a regular `unsafe_load` with the
# same index and alignment. The SPIR-V compiler may still apply appropriate
# optimizations based on context.
@device_function @inline function unsafe_cached_load(ptr::LLVMPtr{T, A}, i::Integer, align::Val) where {T, A}
unsafe_load(ptr, i, align)
end

@device_function @inline function const_arrayref(A::oneDeviceArray{T}, index::Integer) where {T}
# simplified bounds check (see `arrayset`)
#@boundscheck checkbounds(A, index)
Expand Down
1 change: 1 addition & 0 deletions src/oneAPI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ include("utils.jl")

include("oneAPIKernels.jl")
import .oneAPIKernels: oneAPIBackend
include("accumulate.jl")
include("indexing.jl")
export oneAPIBackend

Expand Down
102 changes: 96 additions & 6 deletions src/oneAPIKernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,25 @@ import Adapt
export oneAPIBackend

# KernelAbstractions.jl backend tag for oneAPI devices.
struct oneAPIBackend <: KA.GPU
prefer_blocks::Bool   # bias the launch heuristic towards more, smaller workgroups
always_inline::Bool   # forwarded as `always_inline` to `@oneapi` when compiling kernels
end

KA.allocate(::oneAPIBackend, ::Type{T}, dims::Tuple) where T = oneArray{T}(undef, dims)
KA.zeros(::oneAPIBackend, ::Type{T}, dims::Tuple) where T = oneAPI.zeros(T, dims)
KA.ones(::oneAPIBackend, ::Type{T}, dims::Tuple) where T = oneAPI.ones(T, dims)
# Keyword convenience constructor; both tuning knobs default to off.
oneAPIBackend(; prefer_blocks = false, always_inline = false) = oneAPIBackend(prefer_blocks, always_inline)

# Array allocation for the backend. `unified = true` allocates host-accessible
# shared memory (`SharedBuffer`); otherwise device-local memory (`DeviceBuffer`)
# is used. `zeros`/`ones` allocate uninitialized storage and then `fill!` it.
@inline KA.allocate(::oneAPIBackend, ::Type{T}, dims::Tuple; unified::Bool = false) where {T} = oneArray{T, length(dims), unified ? oneAPI.oneL0.SharedBuffer : oneAPI.oneL0.DeviceBuffer}(undef, dims)
@inline KA.zeros(::oneAPIBackend, ::Type{T}, dims::Tuple; unified::Bool = false) where {T} = fill!(oneArray{T, length(dims), unified ? oneAPI.oneL0.SharedBuffer : oneAPI.oneL0.DeviceBuffer}(undef, dims), zero(T))
@inline KA.ones(::oneAPIBackend, ::Type{T}, dims::Tuple; unified::Bool = false) where {T} = fill!(oneArray{T, length(dims), unified ? oneAPI.oneL0.SharedBuffer : oneAPI.oneL0.DeviceBuffer}(undef, dims), one(T))

KA.get_backend(::oneArray) = oneAPIBackend()
# TODO should be non-blocking
KA.synchronize(::oneAPIBackend) = oneL0.synchronize()
# Block until outstanding operations have finished.
KA.synchronize(::oneAPIBackend) = oneAPI.oneL0.synchronize()
KA.supports_float64(::oneAPIBackend) = false # TODO: Check if this is device dependent
# Unified (shared) memory allocations are supported via the `unified` kwarg.
KA.supports_unified(::oneAPIBackend) = true

# Whether a working oneAPI installation is available on this system.
KA.functional(::oneAPIBackend) = oneAPI.functional()

Adapt.adapt_storage(::oneAPIBackend, a::Array) = Adapt.adapt(oneArray, a)
Adapt.adapt_storage(::oneAPIBackend, a::AbstractArray) = Adapt.adapt(oneArray, a)
Adapt.adapt_storage(::oneAPIBackend, a::oneArray) = a
Adapt.adapt_storage(::KA.CPU, a::oneArray) = convert(Array, a)

Expand All @@ -39,6 +46,24 @@ function KA.copyto!(::oneAPIBackend, A, B)
end


## Device Operations

# Number of oneAPI devices visible to the backend.
KA.ndevices(::oneAPIBackend) = length(oneAPI.devices())

# 1-based index of the currently active device within `oneAPI.devices()`.
function KA.device(::oneAPIBackend)::Int
dev = oneAPI.device()
devs = oneAPI.devices()
idx = findfirst(==(dev), devs)
# NOTE(review): falling back to 1 silently mislabels the device when the active
# device is not found in the list (e.g. if `==` on devices does not compare by
# handle) — consider erroring instead; verify equality semantics for ZeDevice.
return idx === nothing ? 1 : idx
end

# Select device `id` (1-based) as the active device for the current task.
KA.device!(::oneAPIBackend, id::Int) = oneAPI.device!(id)


## Kernel Launch

function KA.mkcontext(kernel::KA.Kernel{oneAPIBackend}, _ndrange, iterspace)
Expand Down Expand Up @@ -83,14 +108,42 @@ function threads_to_workgroupsize(threads, ndrange)
end

function (obj::KA.Kernel{oneAPIBackend})(args...; ndrange=nothing, workgroupsize=nothing)
backend = KA.backend(obj)

ndrange, workgroupsize, iterspace, dynamic = KA.launch_config(obj, ndrange, workgroupsize)
# this might not be the final context, since we may tune the workgroupsize
ctx = KA.mkcontext(obj, ndrange, iterspace)
kernel = @oneapi launch=false obj.f(ctx, args...)

# If the kernel is statically sized we can tell the compiler about that
if KA.workgroupsize(obj) <: KA.StaticSize
# TODO: maxthreads
# maxthreads = prod(KA.get(KA.workgroupsize(obj)))
else
# maxthreads = nothing
end

kernel = @oneapi launch = false always_inline = backend.always_inline obj.f(ctx, args...)

# figure out the optimal workgroupsize automatically
if KA.workgroupsize(obj) <: KA.DynamicSize && workgroupsize === nothing
items = oneAPI.launch_configuration(kernel)

if backend.prefer_blocks
# Prefer blocks over threads:
# Reducing the workgroup size (items) increases the number of workgroups (blocks).
# We use a simple heuristic here since we lack full occupancy info (max_blocks) from launch_configuration.

# If the total range is large enough, full workgroups are fine.
# If the range is small, we might want to reduce 'items' to create more blocks to fill the GPU.
# (Simplified logic compared to CUDA.jl which uses explicit occupancy calculators)
total_items = prod(ndrange)
if total_items < items * 16 # Heuristic factor
# Force at least a few blocks if possible by reducing items per block
target_blocks = 16 # Target at least 16 blocks
items = max(1, min(items, cld(total_items, target_blocks)))
end
end

workgroupsize = threads_to_workgroupsize(items, ndrange)
iterspace, dynamic = KA.partition(obj, ndrange, workgroupsize)
ctx = KA.mkcontext(obj, ndrange, iterspace)
Expand Down Expand Up @@ -171,6 +224,43 @@ end

## Other

# Wrap device arrays in `Const` when adapted through KA's const-argument
# mechanism, marking them read-only for the compiler.
Adapt.adapt_storage(to::KA.ConstAdaptor, a::oneDeviceArray) = Base.Experimental.Const(a)

# Convert kernel arguments to their device-side representation.
KA.argconvert(::KA.Kernel{oneAPIBackend}, arg) = kernel_convert(arg)

"""
    KA.priority!(::oneAPIBackend, prio::Symbol)

Set the execution priority of the task-local command queue to `:high`,
`:normal`, or `:low` by replacing the cached queue with a new one created at
the requested priority. Synchronizes the current queue before replacing it.
"""
function KA.priority!(::oneAPIBackend, prio::Symbol)
if !(prio in (:high, :normal, :low))
error("priority must be one of :high, :normal, :low")
end

# Map the symbolic priority onto the Level Zero command-queue priority enum.
priority_enum = if prio == :high
oneAPI.oneL0.ZE_COMMAND_QUEUE_PRIORITY_PRIORITY_HIGH
elseif prio == :low
oneAPI.oneL0.ZE_COMMAND_QUEUE_PRIORITY_PRIORITY_LOW
else
oneAPI.oneL0.ZE_COMMAND_QUEUE_PRIORITY_NORMAL
end

ctx = oneAPI.context()
dev = oneAPI.device()

# Drain the currently cached queue before swapping it out, so no in-flight
# work is orphaned on the old queue.
current_queue = oneAPI.global_queue(ctx, dev)
oneAPI.oneL0.synchronize(current_queue)

# Replace the queue cached in task_local_storage. NOTE(review): this assumes
# `global_queue` caches its queue under the key `(:ZeCommandQueue, ctx, dev)` —
# verify against its implementation, or the new priority will not take effect.
new_queue = oneAPI.oneL0.ZeCommandQueue(
ctx, dev;
flags = oneAPI.oneL0.ZE_COMMAND_QUEUE_FLAG_IN_ORDER,
priority = priority_enum
)

task_local_storage((:ZeCommandQueue, ctx, dev), new_queue)

return nothing
end

end
1 change: 0 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand Down
2 changes: 1 addition & 1 deletion test/setup.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Distributed, Test, oneAPI, AcceleratedKernels
using Distributed, Test, oneAPI

oneAPI.functional() || error("oneAPI.jl is not functional on this system")

Expand Down
Loading