Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
7a910a3
feat(where): Add IL-generated SIMD optimization for np.where(conditio…
Nucs Apr 12, 2026
c335e0a
perf(where): AVX2/SSE4.1 optimize mask expansion in np.where kernel
Nucs Apr 12, 2026
753d753
perf(where): inline mask creation in IL - 5.4x faster kernel
Nucs Apr 12, 2026
25859e5
fix(where): implement NumPy 2.x NEP50 type promotion for np.where
Nucs Apr 15, 2026
653af58
feat(asanyarray): support all built-in C# collection types
Nucs Apr 15, 2026
974e70d
feat(asanyarray): add non-generic IEnumerable/IEnumerator fallback
Nucs Apr 15, 2026
23ad1c1
refactor(asanyarray): consolidate duplicate code
Nucs Apr 15, 2026
06a43c2
fix(asanyarray): add Tuple/ValueTuple support, fix empty collection h…
Nucs Apr 15, 2026
3d3af19
fix(asanyarray): add NumPy-like type promotion for mixed-type collect…
Nucs Apr 15, 2026
a3205e9
perf(asanyarray): optimize non-generic collection conversion ~4x faster
Nucs Apr 15, 2026
6b0f147
perf(asanyarray): add ToArrayFast with CollectionsMarshal/CopyTo opti…
Nucs Apr 15, 2026
dd1ae2a
perf(asanyarray): use GC.AllocateUninitializedArray to skip zeroing
Nucs Apr 15, 2026
44dd163
perf(asanyarray): optimize with CollectionsMarshal.AsSpan and early exit
Nucs Apr 15, 2026
4e8af8b
refactor(tests): migrate np.where and np.asanyarray tests from TUnit …
Nucs Apr 20, 2026
ae4f1b8
fix(asanyarray): handle object[] via type-promotion path
Nucs Apr 20, 2026
f0473d2
cleanup(where, asanyarray): remove dead code and trim noisy comments
Nucs Apr 20, 2026
3811960
fix(asanyarray,where): pure-float object[] promotion + hot-path short…
Nucs Apr 20, 2026
21d7eec
refactor(where): consolidate reflection cache into partial CachedMethods
Nucs Apr 20, 2026
a5862bd
fix(where): gate x86-specific SIMD path on Sse41/Avx2 for ARM64 compa…
Nucs Apr 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 230 additions & 0 deletions src/NumSharp.Core/APIs/np.where.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
using System;
using NumSharp.Backends.Kernels;
using NumSharp.Generic;

namespace NumSharp
{
public static partial class np
{
/// <summary>
/// Single-argument form of <c>np.where</c>: forwards <paramref name="condition"/> to
/// <see cref="nonzero(NDArray)"/> and returns its result unchanged.
/// </summary>
/// <param name="condition">Array whose non-zero entries are to be located.</param>
/// <returns>One index array per dimension, holding the positions where <paramref name="condition"/> is non-zero.</returns>
/// <remarks>https://numpy.org/doc/stable/reference/generated/numpy.where.html</remarks>
public static NDArray<long>[] where(NDArray condition)
    => nonzero(condition);

/// <summary>
/// Builds an array by picking, element by element, from <paramref name="x"/> or <paramref name="y"/>
/// according to <paramref name="condition"/>.
/// </summary>
/// <param name="condition">Selector: a true entry takes the value from <paramref name="x"/>, a false entry from <paramref name="y"/>.</param>
/// <param name="x">Source of values for positions where the condition is true.</param>
/// <param name="y">Source of values for positions where the condition is false.</param>
/// <returns>Array containing elements of <paramref name="x"/> where the condition holds and elements of <paramref name="y"/> everywhere else.</returns>
/// <remarks>https://numpy.org/doc/stable/reference/generated/numpy.where.html</remarks>
public static NDArray where(NDArray condition, NDArray x, NDArray y)
    => where_internal(condition, x, y);

/// <summary>
/// Overload of <c>np.where</c> taking <paramref name="x"/> as a plain object (e.g. a scalar).
/// The value is converted with <c>asanyarray</c> before the selection is performed.
/// </summary>
/// <param name="condition">Selector: a true entry takes the value from <paramref name="x"/>, a false entry from <paramref name="y"/>.</param>
/// <param name="x">Value converted to an array and used where the condition is true.</param>
/// <param name="y">Source of values for positions where the condition is false.</param>
/// <returns>Array with elements from <paramref name="x"/> where the condition holds and from <paramref name="y"/> elsewhere.</returns>
public static NDArray where(NDArray condition, object x, NDArray y)
    => where_internal(condition, asanyarray(x), y);

/// <summary>
/// Overload of <c>np.where</c> taking <paramref name="y"/> as a plain object (e.g. a scalar).
/// The value is converted with <c>asanyarray</c> before the selection is performed.
/// </summary>
/// <param name="condition">Selector: a true entry takes the value from <paramref name="x"/>, a false entry from <paramref name="y"/>.</param>
/// <param name="x">Source of values for positions where the condition is true.</param>
/// <param name="y">Value converted to an array and used where the condition is false.</param>
/// <returns>Array with elements from <paramref name="x"/> where the condition holds and from <paramref name="y"/> elsewhere.</returns>
public static NDArray where(NDArray condition, NDArray x, object y)
    => where_internal(condition, x, asanyarray(y));

/// <summary>
/// Overload of <c>np.where</c> taking both <paramref name="x"/> and <paramref name="y"/> as plain
/// objects (e.g. scalars). Each is converted with <c>asanyarray</c>, <paramref name="x"/> first,
/// before the selection is performed.
/// </summary>
/// <param name="condition">Selector: a true entry takes the value from <paramref name="x"/>, a false entry from <paramref name="y"/>.</param>
/// <param name="x">Value converted to an array and used where the condition is true.</param>
/// <param name="y">Value converted to an array and used where the condition is false.</param>
/// <returns>Array with elements from <paramref name="x"/> where the condition holds and from <paramref name="y"/> elsewhere.</returns>
public static NDArray where(NDArray condition, object x, object y)
    => where_internal(condition, asanyarray(x), asanyarray(y));

/// <summary>
/// Shared implementation behind every three-argument <c>np.where</c> overload: each output element is
/// taken from <paramref name="x"/> where <paramref name="condition"/> is true and from
/// <paramref name="y"/> otherwise.
/// </summary>
/// <param name="condition">Selector array. Broadcast together with <paramref name="x"/> and <paramref name="y"/> when the three shapes are not already equal.</param>
/// <param name="x">Values used where the condition is true.</param>
/// <param name="y">Values used where the condition is false.</param>
/// <returns>
/// An array created with <c>empty</c> using the (possibly broadcast) condition's dimensions and the
/// common type code of <paramref name="x"/> and <paramref name="y"/>, filled with the selected values.
/// </returns>
/// <exception cref="ArgumentNullException">
/// <paramref name="condition"/>, <paramref name="x"/> or <paramref name="y"/> is <c>null</c>.
/// </exception>
/// <exception cref="NotSupportedException">The resolved output type code has no dispatch case.</exception>
private static NDArray where_internal(NDArray condition, NDArray x, NDArray y)
{
    // Every public three-argument overload funnels through here, so validating once covers them all.
    // Without these guards a null argument surfaces as a NullReferenceException from the first member
    // access below, with no hint about which argument was at fault. `is null` is a pure reference
    // check, so no user-defined equality operator can be picked up by the comparison.
    if (condition is null)
        throw new ArgumentNullException(nameof(condition));
    if (x is null)
        throw new ArgumentNullException(nameof(x));
    if (y is null)
        throw new ArgumentNullException(nameof(y));

    // Skip broadcast_arrays (which allocates 3 NDArrays + helper arrays) when all three
    // already share a shape — the frequent case of np.where(mask, arr, other_arr).
    NDArray cond, xArr, yArr;
    if (condition.Shape == x.Shape && x.Shape == y.Shape)
    {
        cond = condition;
        xArr = x;
        yArr = y;
    }
    else
    {
        var broadcasted = broadcast_arrays(condition, x, y);
        cond = broadcasted[0];
        xArr = broadcasted[1];
        yArr = broadcasted[2];
    }

    // When x and y already agree, skip the NEP50 promotion lookup. Otherwise defer to
    // _FindCommonType which handles the scalar+array NEP50 rules. The ORIGINAL x/y are passed
    // (not the broadcast views) — the promotion is decided before any broadcasting is visible.
    var outType = x.GetTypeCode == y.GetTypeCode
        ? x.GetTypeCode
        : _FindCommonType(x, y);

    // Bring both value operands to the output type; copy: false is requested so an operand
    // that already matches is not duplicated.
    if (xArr.GetTypeCode != outType)
        xArr = xArr.astype(outType, copy: false);
    if (yArr.GetTypeCode != outType)
        yArr = yArr.astype(outType, copy: false);

    // Use cond.shape (dimensions only) not cond.Shape (which may have broadcast strides)
    var result = empty(cond.shape, outType);

    // Handle empty arrays - nothing to iterate
    if (result.size == 0)
        return result;

    // IL Kernel fast path: all arrays contiguous, bool condition, SIMD enabled
    // Broadcasted arrays (stride=0) are NOT contiguous, so they use iterator path.
    bool canUseKernel = ILKernelGenerator.Enabled &&
                        cond.typecode == NPTypeCode.Boolean &&
                        cond.Shape.IsContiguous &&
                        xArr.Shape.IsContiguous &&
                        yArr.Shape.IsContiguous;

    if (canUseKernel)
    {
        WhereKernelDispatch(cond, xArr, yArr, result, outType);
        return result;
    }

    // Iterator fallback for non-contiguous/broadcasted arrays
    switch (outType)
    {
        case NPTypeCode.Boolean:
            WhereImpl<bool>(cond, xArr, yArr, result);
            break;
        case NPTypeCode.Byte:
            WhereImpl<byte>(cond, xArr, yArr, result);
            break;
        case NPTypeCode.Int16:
            WhereImpl<short>(cond, xArr, yArr, result);
            break;
        case NPTypeCode.UInt16:
            WhereImpl<ushort>(cond, xArr, yArr, result);
            break;
        case NPTypeCode.Int32:
            WhereImpl<int>(cond, xArr, yArr, result);
            break;
        case NPTypeCode.UInt32:
            WhereImpl<uint>(cond, xArr, yArr, result);
            break;
        case NPTypeCode.Int64:
            WhereImpl<long>(cond, xArr, yArr, result);
            break;
        case NPTypeCode.UInt64:
            WhereImpl<ulong>(cond, xArr, yArr, result);
            break;
        case NPTypeCode.Char:
            WhereImpl<char>(cond, xArr, yArr, result);
            break;
        case NPTypeCode.Single:
            WhereImpl<float>(cond, xArr, yArr, result);
            break;
        case NPTypeCode.Double:
            WhereImpl<double>(cond, xArr, yArr, result);
            break;
        case NPTypeCode.Decimal:
            WhereImpl<decimal>(cond, xArr, yArr, result);
            break;
        default:
            throw new NotSupportedException($"Type {outType} not supported for np.where");
    }

    return result;
}

/// <summary>
/// Iterator-based selection used when the kernel fast path is not taken. Walks the condition,
/// both value arrays and the destination in lockstep and stores, per element, the value from
/// <paramref name="x"/> when the condition is true, otherwise the value from <paramref name="y"/>.
/// </summary>
/// <typeparam name="T">Element type through which <paramref name="x"/>, <paramref name="y"/> and <paramref name="result"/> are iterated.</typeparam>
/// <param name="cond">Selector array, iterated as <c>bool</c>.</param>
/// <param name="x">Values used where the condition is true.</param>
/// <param name="y">Values used where the condition is false.</param>
/// <param name="result">Destination array, written through the iterator's element references.</param>
private static void WhereImpl<T>(NDArray cond, NDArray x, NDArray y, NDArray result) where T : unmanaged
{
    // Iterators (rather than raw pointers) so broadcasted/strided inputs are walked correctly.
    using (var selector = cond.AsIterator<bool>())
    using (var whenTrue = x.AsIterator<T>())
    using (var whenFalse = y.AsIterator<T>())
    using (var sink = result.AsIterator<T>())
    {
        // NOTE(review): termination is decided by the condition iterator alone; the other three
        // are presumed to yield the same number of elements — confirm against callers.
        while (selector.HasNext())
        {
            // Both value iterators advance on every step, whichever value ends up being
            // stored, so all four stay aligned on the same element.
            bool takeFirst = selector.MoveNext();
            T candidateX = whenTrue.MoveNext();
            T candidateY = whenFalse.MoveNext();

            ref T slot = ref sink.MoveNextReference();
            if (takeFirst)
                slot = candidateX;
            else
                slot = candidateY;
        }
    }
}

/// <summary>
/// Fast-path dispatcher: reinterprets the raw <c>Address</c> of each array as a pointer of the
/// element type selected by <paramref name="outType"/> and hands them to
/// <c>ILKernelGenerator.WhereExecute</c>. Intended for contiguous arrays only, since the data is
/// addressed linearly through the pointers.
/// </summary>
/// <param name="cond">Condition array; its address is read as <c>bool*</c>.</param>
/// <param name="x">Values used where the condition is true.</param>
/// <param name="y">Values used where the condition is false.</param>
/// <param name="result">Destination array; its <c>size</c> is passed to the kernel as the element count.</param>
/// <param name="outType">Type code choosing the pointer element type for x, y and result.</param>
/// <exception cref="NotSupportedException"><paramref name="outType"/> has no matching case.</exception>
private static unsafe void WhereKernelDispatch(NDArray cond, NDArray x, NDArray y, NDArray result, NPTypeCode outType)
{
    bool* mask = (bool*)cond.Address;
    var length = result.size;

    // Each case differs only in the pointer element type; nothing runs after the switch,
    // so every handled case returns directly.
    switch (outType)
    {
        case NPTypeCode.Boolean:
            ILKernelGenerator.WhereExecute(mask, (bool*)x.Address, (bool*)y.Address, (bool*)result.Address, length);
            return;

        case NPTypeCode.Byte:
            ILKernelGenerator.WhereExecute(mask, (byte*)x.Address, (byte*)y.Address, (byte*)result.Address, length);
            return;

        case NPTypeCode.Int16:
            ILKernelGenerator.WhereExecute(mask, (short*)x.Address, (short*)y.Address, (short*)result.Address, length);
            return;

        case NPTypeCode.UInt16:
            ILKernelGenerator.WhereExecute(mask, (ushort*)x.Address, (ushort*)y.Address, (ushort*)result.Address, length);
            return;

        case NPTypeCode.Int32:
            ILKernelGenerator.WhereExecute(mask, (int*)x.Address, (int*)y.Address, (int*)result.Address, length);
            return;

        case NPTypeCode.UInt32:
            ILKernelGenerator.WhereExecute(mask, (uint*)x.Address, (uint*)y.Address, (uint*)result.Address, length);
            return;

        case NPTypeCode.Int64:
            ILKernelGenerator.WhereExecute(mask, (long*)x.Address, (long*)y.Address, (long*)result.Address, length);
            return;

        case NPTypeCode.UInt64:
            ILKernelGenerator.WhereExecute(mask, (ulong*)x.Address, (ulong*)y.Address, (ulong*)result.Address, length);
            return;

        case NPTypeCode.Char:
            ILKernelGenerator.WhereExecute(mask, (char*)x.Address, (char*)y.Address, (char*)result.Address, length);
            return;

        case NPTypeCode.Single:
            ILKernelGenerator.WhereExecute(mask, (float*)x.Address, (float*)y.Address, (float*)result.Address, length);
            return;

        case NPTypeCode.Double:
            ILKernelGenerator.WhereExecute(mask, (double*)x.Address, (double*)y.Address, (double*)result.Address, length);
            return;

        case NPTypeCode.Decimal:
            ILKernelGenerator.WhereExecute(mask, (decimal*)x.Address, (decimal*)y.Address, (decimal*)result.Address, length);
            return;

        default:
            throw new NotSupportedException($"Type {outType} not supported for np.where");
    }
}
}
}
Loading
Loading