// diff --git a/src/NumSharp.Core/APIs/np.where.cs b/src/NumSharp.Core/APIs/np.where.cs
// new file mode 100644
// index 00000000..14633e3c
// --- /dev/null
// +++ b/src/NumSharp.Core/APIs/np.where.cs
// @@ -0,0 +1,230 @@
using System;
using NumSharp.Backends.Kernels;
using NumSharp.Generic;

namespace NumSharp
{
    public static partial class np
    {
        /// <summary>
        /// Single-argument form: forwards to <c>nonzero(condition)</c> and returns its result unchanged.
        /// </summary>
        /// <param name="condition">Input array. Non-zero entries yield their indices.</param>
        /// <returns>Tuple of arrays with indices where condition is non-zero, one per dimension.</returns>
        /// <remarks>https://numpy.org/doc/stable/reference/generated/numpy.where.html</remarks>
        public static NDArray[] where(NDArray condition)
        {
            return nonzero(condition);
        }

        /// <summary>
        /// Return elements chosen from `x` or `y` depending on `condition`.
        /// </summary>
        /// <param name="condition">Where True, yield `x`, otherwise yield `y`.</param>
        /// <param name="x">Values from which to choose where condition is True.</param>
        /// <param name="y">Values from which to choose where condition is False.</param>
        /// <returns>An array with elements from `x` where `condition` is True, and elements from `y` elsewhere.</returns>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        /// <remarks>https://numpy.org/doc/stable/reference/generated/numpy.where.html</remarks>
        public static NDArray where(NDArray condition, NDArray x, NDArray y)
        {
            return where_internal(condition, x, y);
        }

        /// <summary>
        /// Return elements chosen from `x` or `y` depending on `condition`.
        /// Scalar overload for x: `x` is first converted with <c>asanyarray</c>.
        /// </summary>
        public static NDArray where(NDArray condition, object x, NDArray y)
        {
            return where_internal(condition, asanyarray(x), y);
        }

        /// <summary>
        /// Return elements chosen from `x` or `y` depending on `condition`.
        /// Scalar overload for y: `y` is first converted with <c>asanyarray</c>.
        /// </summary>
        public static NDArray where(NDArray condition, NDArray x, object y)
        {
            return where_internal(condition, x, asanyarray(y));
        }

        /// <summary>
        /// Return elements chosen from `x` or `y` depending on `condition`.
        /// Scalar overload for both x and y: both are first converted with <c>asanyarray</c>.
        /// </summary>
        public static NDArray where(NDArray condition, object x, object y)
        {
            return where_internal(condition, asanyarray(x), asanyarray(y));
        }

        /// <summary>
        /// Internal implementation of np.where: broadcasts the three inputs when their shapes
        /// differ, promotes x/y to one output dtype, then fills a freshly allocated result via
        /// the IL kernel (contiguous bool condition) or the iterator fallback.
        /// </summary>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        /// <exception cref="NotSupportedException">The promoted dtype has no dispatch case.</exception>
        private static NDArray where_internal(NDArray condition, NDArray x, NDArray y)
        {
            // `is null` (not `== null`) so the guard never goes through a user-defined equality operator.
            if (condition is null)
                throw new ArgumentNullException(nameof(condition));
            if (x is null)
                throw new ArgumentNullException(nameof(x));
            if (y is null)
                throw new ArgumentNullException(nameof(y));

            // Skip broadcast_arrays (which allocates 3 NDArrays + helper arrays) when all three
            // already share a shape — the frequent case of np.where(mask, arr, other_arr).
            NDArray cond, xArr, yArr;
            if (condition.Shape == x.Shape && x.Shape == y.Shape)
            {
                cond = condition;
                xArr = x;
                yArr = y;
            }
            else
            {
                var broadcasted = broadcast_arrays(condition, x, y);
                cond = broadcasted[0];
                xArr = broadcasted[1];
                yArr = broadcasted[2];
            }

            // When x and y already agree, skip the NEP50 promotion lookup. Otherwise defer to
            // _FindCommonType which handles the scalar+array NEP50 rules.
            var outType = x.GetTypeCode == y.GetTypeCode
                ? x.GetTypeCode
                : _FindCommonType(x, y);

            if (xArr.GetTypeCode != outType)
                xArr = xArr.astype(outType, copy: false);
            if (yArr.GetTypeCode != outType)
                yArr = yArr.astype(outType, copy: false);

            // Use cond.shape (dimensions only) not cond.Shape (which may have broadcast strides)
            var result = empty(cond.shape, outType);

            // Handle empty arrays - nothing to iterate
            if (result.size == 0)
                return result;

            // IL Kernel fast path: all arrays contiguous, bool condition, SIMD enabled
            // Broadcasted arrays (stride=0) are NOT contiguous, so they use iterator path.
            // NOTE(review): the kernel reads from .Address directly; this presumes Address already
            // points at the first logical element of a contiguous view — confirm for sliced inputs.
            bool canUseKernel = ILKernelGenerator.Enabled &&
                                cond.typecode == NPTypeCode.Boolean &&
                                cond.Shape.IsContiguous &&
                                xArr.Shape.IsContiguous &&
                                yArr.Shape.IsContiguous;

            if (canUseKernel)
            {
                WhereKernelDispatch(cond, xArr, yArr, result, outType);
                return result;
            }

            // Iterator fallback for non-contiguous/broadcasted arrays
            switch (outType)
            {
                case NPTypeCode.Boolean:
                    WhereImpl<bool>(cond, xArr, yArr, result);
                    break;
                case NPTypeCode.Byte:
                    WhereImpl<byte>(cond, xArr, yArr, result);
                    break;
                case NPTypeCode.Int16:
                    WhereImpl<short>(cond, xArr, yArr, result);
                    break;
                case NPTypeCode.UInt16:
                    WhereImpl<ushort>(cond, xArr, yArr, result);
                    break;
                case NPTypeCode.Int32:
                    WhereImpl<int>(cond, xArr, yArr, result);
                    break;
                case NPTypeCode.UInt32:
                    WhereImpl<uint>(cond, xArr, yArr, result);
                    break;
                case NPTypeCode.Int64:
                    WhereImpl<long>(cond, xArr, yArr, result);
                    break;
                case NPTypeCode.UInt64:
                    WhereImpl<ulong>(cond, xArr, yArr, result);
                    break;
                case NPTypeCode.Char:
                    WhereImpl<char>(cond, xArr, yArr, result);
                    break;
                case NPTypeCode.Single:
                    WhereImpl<float>(cond, xArr, yArr, result);
                    break;
                case NPTypeCode.Double:
                    WhereImpl<double>(cond, xArr, yArr, result);
                    break;
                case NPTypeCode.Decimal:
                    WhereImpl<decimal>(cond, xArr, yArr, result);
                    break;
                default:
                    throw new NotSupportedException($"Type {outType} not supported for np.where");
            }

            return result;
        }

        /// <summary>
        /// Element-by-element selection through iterators: for every condition element, writes
        /// the matching x element when it is true and the matching y element otherwise.
        /// </summary>
        /// <typeparam name="T">Element type of x, y and result.</typeparam>
        private static void WhereImpl<T>(NDArray cond, NDArray x, NDArray y, NDArray result) where T : unmanaged
        {
            // Use iterators for proper handling of broadcasted/strided arrays
            using var condIter = cond.AsIterator<bool>();
            using var xIter = x.AsIterator<T>();
            using var yIter = y.AsIterator<T>();
            using var resultIter = result.AsIterator<T>();

            while (condIter.HasNext())
            {
                // All three source iterators advance on every step, whichever value is selected.
                var c = condIter.MoveNext();
                var xVal = xIter.MoveNext();
                var yVal = yIter.MoveNext();
                resultIter.MoveNextReference() = c ? xVal : yVal;
            }
        }

        /// <summary>
        /// IL Kernel dispatch for contiguous arrays.
        /// Casts the raw addresses to the pointer type matching <paramref name="outType"/> and
        /// forwards to <c>ILKernelGenerator.WhereExecute</c> with <c>result.size</c> as the count.
        /// </summary>
        /// <exception cref="NotSupportedException">No case exists for <paramref name="outType"/>.</exception>
        private static unsafe void WhereKernelDispatch(NDArray cond, NDArray x, NDArray y, NDArray result, NPTypeCode outType)
        {
            var condPtr = (bool*)cond.Address;
            var count = result.size;

            switch (outType)
            {
                case NPTypeCode.Boolean:
                    ILKernelGenerator.WhereExecute<bool>(condPtr, (bool*)x.Address, (bool*)y.Address, (bool*)result.Address, count);
                    break;
                case NPTypeCode.Byte:
                    ILKernelGenerator.WhereExecute<byte>(condPtr, (byte*)x.Address, (byte*)y.Address, (byte*)result.Address, count);
                    break;
                case NPTypeCode.Int16:
                    ILKernelGenerator.WhereExecute<short>(condPtr, (short*)x.Address, (short*)y.Address, (short*)result.Address, count);
                    break;
                case NPTypeCode.UInt16:
                    ILKernelGenerator.WhereExecute<ushort>(condPtr, (ushort*)x.Address, (ushort*)y.Address, (ushort*)result.Address, count);
                    break;
                case NPTypeCode.Int32:
                    ILKernelGenerator.WhereExecute<int>(condPtr, (int*)x.Address, (int*)y.Address, (int*)result.Address, count);
                    break;
                case NPTypeCode.UInt32:
                    ILKernelGenerator.WhereExecute<uint>(condPtr, (uint*)x.Address, (uint*)y.Address, (uint*)result.Address, count);
                    break;
                case NPTypeCode.Int64:
                    ILKernelGenerator.WhereExecute<long>(condPtr, (long*)x.Address, (long*)y.Address, (long*)result.Address, count);
                    break;
                case NPTypeCode.UInt64:
                    ILKernelGenerator.WhereExecute<ulong>(condPtr, (ulong*)x.Address, (ulong*)y.Address, (ulong*)result.Address, count);
                    break;
                case NPTypeCode.Char:
                    ILKernelGenerator.WhereExecute<char>(condPtr, (char*)x.Address, (char*)y.Address, (char*)result.Address, count);
                    break;
                case NPTypeCode.Single:
                    ILKernelGenerator.WhereExecute<float>(condPtr, (float*)x.Address, (float*)y.Address, (float*)result.Address, count);
                    break;
                case NPTypeCode.Double:
                    ILKernelGenerator.WhereExecute<double>(condPtr, (double*)x.Address, (double*)y.Address, (double*)result.Address, count);
                    break;
                case NPTypeCode.Decimal:
                    ILKernelGenerator.WhereExecute<decimal>(condPtr, (decimal*)x.Address, (decimal*)y.Address, (decimal*)result.Address, count);
                    break;
                default:
                    throw new NotSupportedException($"Type {outType} not supported for np.where");
            }
        }
    }
}

// diff --git a/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs
// new file mode 100644
// index 00000000..72678ca7
// --- /dev/null
// +++ b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs
// @@ -0,0 +1,699 @@
using System;
using System.Collections.Concurrent;
using System.Reflection;
using System.Reflection.Emit;
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
using NumSharp.Utilities;

// =============================================================================
// ILKernelGenerator.Where - IL-generated np.where(condition, x, y) kernels
// =============================================================================
//
// RESPONSIBILITY:
//   - Generate optimized kernels for conditional selection
//   - result[i] = cond[i] ? x[i] : y[i]
//
// ARCHITECTURE:
//   Uses IL emission to generate type-specific kernels at runtime.
//   The challenge is bool mask expansion: condition is bool[] (1 byte per element),
//   but the vectorised x/y dtypes are 1-8 bytes per element.
//
//   | Element Size | V256 Elements | Bools to Load |
//   |--------------|---------------|---------------|
//   | 1 byte       | 32            | 32            |
//   | 2 bytes      | 16            | 16            |
//   | 4 bytes      | 8             | 8             |
//   | 8 bytes      | 4             | 4             |
//
// KERNEL TYPES:
//   - WhereKernel<T>: Main kernel delegate (cond*, x*, y*, result*, count)
//
// =============================================================================

namespace NumSharp.Backends.Kernels
{
    /// <summary>
    /// Delegate for where operation kernels.
    /// </summary>
    public unsafe delegate void WhereKernel<T>(bool* cond, T* x, T* y, T* result, long count) where T : unmanaged;

    public static partial class ILKernelGenerator
    {
        /// <summary>
        /// Cache of IL-generated where kernels.
+ /// Key: Type + /// + private static readonly ConcurrentDictionary _whereKernelCache = new(); + + #region Public API + + /// + /// Get or generate an IL-based where kernel for the specified type. + /// Returns null if IL generation is disabled or fails. + /// + public static WhereKernel? GetWhereKernel() where T : unmanaged + { + if (!Enabled) + return null; + + var type = typeof(T); + + if (_whereKernelCache.TryGetValue(type, out var cached)) + return (WhereKernel)cached; + + var kernel = TryGenerateWhereKernel(); + if (kernel == null) + return null; + + if (_whereKernelCache.TryAdd(type, kernel)) + return kernel; + + return (WhereKernel)_whereKernelCache[type]; + } + + /// + /// Execute where operation using IL-generated kernel or fallback to static helper. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe void WhereExecute(bool* cond, T* x, T* y, T* result, long count) where T : unmanaged + { + if (count == 0) + return; + + var kernel = GetWhereKernel(); + if (kernel != null) + { + kernel(cond, x, y, result, count); + } + else + { + // Fallback to scalar loop + WhereScalar(cond, x, y, result, count); + } + } + + #endregion + + #region Kernel Generation + + private static WhereKernel? TryGenerateWhereKernel() where T : unmanaged + { + try + { + return GenerateWhereKernelIL(); + } + catch (Exception ex) + { + System.Diagnostics.Debug.WriteLine($"[ILKernel] TryGenerateWhereKernel<{typeof(T).Name}>: {ex.GetType().Name}: {ex.Message}"); + return null; + } + } + + private static unsafe WhereKernel GenerateWhereKernelIL() where T : unmanaged + { + int elementSize = Unsafe.SizeOf(); + + // SIMD eligibility: + // - 1-byte types (byte) only touch portable Vector128/Vector256 APIs, so they work + // on any SIMD-capable platform (including ARM64/Neon). + // - 2/4/8-byte types need Sse41.ConvertToVector128Int* (V128 path) or + // Avx2.ConvertToVector256Int* (V256 path) to expand the bool-mask lanes. 
+ // These x86 intrinsics throw PlatformNotSupportedException on ARM64. + bool canSimdDtype = elementSize <= 8 && IsSimdSupported(); + bool needsX86 = elementSize > 1; + bool useV256 = VectorBits >= 256 && (!needsX86 || Avx2.IsSupported); + bool useV128 = !useV256 && VectorBits >= 128 && (!needsX86 || Sse41.IsSupported); + bool emitSimd = canSimdDtype && (useV256 || useV128); + + var dm = new DynamicMethod( + name: $"IL_Where_{typeof(T).Name}", + returnType: typeof(void), + parameterTypes: new[] { typeof(bool*), typeof(T*), typeof(T*), typeof(T*), typeof(long) }, + owner: typeof(ILKernelGenerator), + skipVisibility: true + ); + + var il = dm.GetILGenerator(); + + // Locals + var locI = il.DeclareLocal(typeof(long)); // loop counter + + // Labels + var lblScalarLoop = il.DefineLabel(); + var lblScalarLoopEnd = il.DefineLabel(); + + // i = 0 + il.Emit(OpCodes.Ldc_I8, 0L); + il.Emit(OpCodes.Stloc, locI); + + if (emitSimd) + { + EmitWhereSIMDLoop(il, locI, useV256); + } + + // Scalar loop for remainder + il.MarkLabel(lblScalarLoop); + + // if (i >= count) goto end + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Ldarg, 4); // count + il.Emit(OpCodes.Bge, lblScalarLoopEnd); + + // result[i] = cond[i] ? x[i] : y[i] + EmitWhereScalarElement(il, locI); + + // i++ + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Ldc_I8, 1L); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Stloc, locI); + + il.Emit(OpCodes.Br, lblScalarLoop); + + il.MarkLabel(lblScalarLoopEnd); + il.Emit(OpCodes.Ret); + + return (WhereKernel)dm.CreateDelegate(typeof(WhereKernel)); + } + + private static void EmitWhereSIMDLoop(ILGenerator il, LocalBuilder locI, bool useV256) where T : unmanaged + { + long elementSize = Unsafe.SizeOf(); + long vectorCount = useV256 ? 
(32 / elementSize) : (16 / elementSize); + long unrollFactor = 4; + long unrollStep = vectorCount * unrollFactor; + + var locUnrollEnd = il.DeclareLocal(typeof(long)); + var locVectorEnd = il.DeclareLocal(typeof(long)); + + var lblUnrollLoop = il.DefineLabel(); + var lblUnrollLoopEnd = il.DefineLabel(); + var lblVectorLoop = il.DefineLabel(); + var lblVectorLoopEnd = il.DefineLabel(); + + // unrollEnd = count - unrollStep (for 4x unrolled loop) + il.Emit(OpCodes.Ldarg, 4); // count + il.Emit(OpCodes.Ldc_I8, unrollStep); + il.Emit(OpCodes.Sub); + il.Emit(OpCodes.Stloc, locUnrollEnd); + + // vectorEnd = count - vectorCount (for remainder loop) + il.Emit(OpCodes.Ldarg, 4); // count + il.Emit(OpCodes.Ldc_I8, vectorCount); + il.Emit(OpCodes.Sub); + il.Emit(OpCodes.Stloc, locVectorEnd); + + // ========== 4x UNROLLED SIMD LOOP ========== + il.MarkLabel(lblUnrollLoop); + + // if (i > unrollEnd) goto UnrollLoopEnd + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Ldloc, locUnrollEnd); + il.Emit(OpCodes.Bgt, lblUnrollLoopEnd); + + // Process 4 vectors per iteration + for (long u = 0; u < unrollFactor; u++) + { + long offset = vectorCount * u; + if (useV256) + EmitWhereV256BodyWithOffset(il, locI, elementSize, offset); + else + EmitWhereV128BodyWithOffset(il, locI, elementSize, offset); + } + + // i += unrollStep + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Ldc_I8, unrollStep); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Stloc, locI); + + il.Emit(OpCodes.Br, lblUnrollLoop); + + il.MarkLabel(lblUnrollLoopEnd); + + // ========== REMAINDER SIMD LOOP (1 vector at a time) ========== + il.MarkLabel(lblVectorLoop); + + // if (i > vectorEnd) goto VectorLoopEnd + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Ldloc, locVectorEnd); + il.Emit(OpCodes.Bgt, lblVectorLoopEnd); + + // Process 1 vector + if (useV256) + EmitWhereV256BodyWithOffset(il, locI, elementSize, 0L); + else + EmitWhereV128BodyWithOffset(il, locI, elementSize, 0L); + + // i += vectorCount + il.Emit(OpCodes.Ldloc, 
locI); + il.Emit(OpCodes.Ldc_I8, vectorCount); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Stloc, locI); + + il.Emit(OpCodes.Br, lblVectorLoop); + + il.MarkLabel(lblVectorLoopEnd); + } + + private static void EmitWhereV256BodyWithOffset(ILGenerator il, LocalBuilder locI, long elementSize, long offset) where T : unmanaged + { + var loadMethod = CachedMethods.V256LoadGeneric.MakeGenericMethod(typeof(T)); + var storeMethod = CachedMethods.V256StoreGeneric.MakeGenericMethod(typeof(T)); + var selectMethod = CachedMethods.V256ConditionalSelectGeneric.MakeGenericMethod(typeof(T)); + + // Load address: cond + (i + offset) + il.Emit(OpCodes.Ldarg_0); // cond + il.Emit(OpCodes.Ldloc, locI); + if (offset > 0) + { + il.Emit(OpCodes.Ldc_I8, offset); + il.Emit(OpCodes.Add); + } + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + + // Inline mask creation - emit AVX2 instructions directly instead of calling helper + EmitInlineMaskCreationV256(il, (int)elementSize); + + // Load x vector: x + (i + offset) * elementSize + il.Emit(OpCodes.Ldarg_1); // x + il.Emit(OpCodes.Ldloc, locI); + if (offset > 0) + { + il.Emit(OpCodes.Ldc_I8, offset); + il.Emit(OpCodes.Add); + } + il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Call, loadMethod); + + // Load y vector: y + (i + offset) * elementSize + il.Emit(OpCodes.Ldarg_2); // y + il.Emit(OpCodes.Ldloc, locI); + if (offset > 0) + { + il.Emit(OpCodes.Ldc_I8, offset); + il.Emit(OpCodes.Add); + } + il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Call, loadMethod); + + // Stack: mask, xVec, yVec + // ConditionalSelect(mask, x, y) + il.Emit(OpCodes.Call, selectMethod); + + // Store result: result + (i + offset) * elementSize + il.Emit(OpCodes.Ldarg_3); // result + il.Emit(OpCodes.Ldloc, locI); + if (offset > 0) + { + il.Emit(OpCodes.Ldc_I8, offset); + il.Emit(OpCodes.Add); + } + 
il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Call, storeMethod); + } + + private static void EmitWhereV128BodyWithOffset(ILGenerator il, LocalBuilder locI, long elementSize, long offset) where T : unmanaged + { + var loadMethod = CachedMethods.V128LoadGeneric.MakeGenericMethod(typeof(T)); + var storeMethod = CachedMethods.V128StoreGeneric.MakeGenericMethod(typeof(T)); + var selectMethod = CachedMethods.V128ConditionalSelectGeneric.MakeGenericMethod(typeof(T)); + + // Load address: cond + (i + offset) + il.Emit(OpCodes.Ldarg_0); + il.Emit(OpCodes.Ldloc, locI); + if (offset > 0) + { + il.Emit(OpCodes.Ldc_I8, offset); + il.Emit(OpCodes.Add); + } + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + + // Inline mask creation - emit SSE4.1 instructions directly + EmitInlineMaskCreationV128(il, (int)elementSize); + + // Load x vector + il.Emit(OpCodes.Ldarg_1); + il.Emit(OpCodes.Ldloc, locI); + if (offset > 0) + { + il.Emit(OpCodes.Ldc_I8, offset); + il.Emit(OpCodes.Add); + } + il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Call, loadMethod); + + // Load y vector + il.Emit(OpCodes.Ldarg_2); + il.Emit(OpCodes.Ldloc, locI); + if (offset > 0) + { + il.Emit(OpCodes.Ldc_I8, offset); + il.Emit(OpCodes.Add); + } + il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Call, loadMethod); + + // ConditionalSelect + il.Emit(OpCodes.Call, selectMethod); + + // Store + il.Emit(OpCodes.Ldarg_3); + il.Emit(OpCodes.Ldloc, locI); + if (offset > 0) + { + il.Emit(OpCodes.Ldc_I8, offset); + il.Emit(OpCodes.Add); + } + il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Call, storeMethod); + } + + private static void EmitWhereScalarElement(ILGenerator il, LocalBuilder 
locI) where T : unmanaged + { + long elementSize = Unsafe.SizeOf(); + var typeCode = InfoOf.NPTypeCode; + + // result[i] = cond[i] ? x[i] : y[i] + var lblFalse = il.DefineLabel(); + var lblEnd = il.DefineLabel(); + + // Load result address: result + i * elementSize + il.Emit(OpCodes.Ldarg_3); + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + + // Load cond[i]: cond + i (bool is 1 byte) + il.Emit(OpCodes.Ldarg_0); + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + il.Emit(OpCodes.Ldind_U1); // Load bool as byte + + // if (!cond[i]) goto lblFalse + il.Emit(OpCodes.Brfalse, lblFalse); + + // True branch: load x[i] + il.Emit(OpCodes.Ldarg_1); + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + EmitLoadIndirect(il, typeCode); + il.Emit(OpCodes.Br, lblEnd); + + // False branch: load y[i] + il.MarkLabel(lblFalse); + il.Emit(OpCodes.Ldarg_2); + il.Emit(OpCodes.Ldloc, locI); + il.Emit(OpCodes.Ldc_I8, elementSize); + il.Emit(OpCodes.Mul); + il.Emit(OpCodes.Conv_I); + il.Emit(OpCodes.Add); + EmitLoadIndirect(il, typeCode); + + il.MarkLabel(lblEnd); + // Stack: result_ptr, value + EmitStoreIndirect(il, typeCode); + } + + #endregion + + #region Inline Mask IL Emission + + // Vector-related MethodInfos for np.where are cached in the partial CachedMethods class + // below (see "Where Kernel Methods" region at the end of this file). + + /// + /// Emit inline V256 mask creation. 
Stack: byte* -> Vector256{T} (as mask) + /// + private static void EmitInlineMaskCreationV256(ILGenerator il, int elementSize) + { + // Stack has: byte* pointing to condition bools + + switch (elementSize) + { + case 8: // double/long: load 4 bytes, expand to 4 qwords + // *(uint*)ptr + il.Emit(OpCodes.Ldind_U4); + // Vector128.CreateScalar(value) + il.Emit(OpCodes.Call, CachedMethods.V128CreateScalarUInt); + // .AsByte() + il.Emit(OpCodes.Call, CachedMethods.V128UIntAsByte); + // Avx2.ConvertToVector256Int64(bytes) + il.Emit(OpCodes.Call, CachedMethods.Avx2ConvertToV256Int64); + // .AsUInt64() + il.Emit(OpCodes.Call, CachedMethods.V256LongAsULong); + // Vector256.Zero + il.Emit(OpCodes.Call, CachedMethods.V256GetZeroULong); + // Vector256.GreaterThan(expanded, zero) + il.Emit(OpCodes.Call, CachedMethods.V256GreaterThanULong); + break; + + case 4: // float/int: load 8 bytes, expand to 8 dwords + // *(ulong*)ptr + il.Emit(OpCodes.Ldind_I8); + // Vector128.CreateScalar(value) + il.Emit(OpCodes.Call, CachedMethods.V128CreateScalarULong); + // .AsByte() + il.Emit(OpCodes.Call, CachedMethods.V128ULongAsByte); + // Avx2.ConvertToVector256Int32(bytes) + il.Emit(OpCodes.Call, CachedMethods.Avx2ConvertToV256Int32); + // .AsUInt32() + il.Emit(OpCodes.Call, CachedMethods.V256IntAsUInt); + // Vector256.Zero + il.Emit(OpCodes.Call, CachedMethods.V256GetZeroUInt); + // Vector256.GreaterThan(expanded, zero) + il.Emit(OpCodes.Call, CachedMethods.V256GreaterThanUInt); + break; + + case 2: // short/char: load 16 bytes, expand to 16 words + // Vector128.Load(ptr) + il.Emit(OpCodes.Call, CachedMethods.V128LoadByte); + // Avx2.ConvertToVector256Int16(bytes) + il.Emit(OpCodes.Call, CachedMethods.Avx2ConvertToV256Int16); + // .AsUInt16() + il.Emit(OpCodes.Call, CachedMethods.V256ShortAsUShort); + // Vector256.Zero + il.Emit(OpCodes.Call, CachedMethods.V256GetZeroUShort); + // Vector256.GreaterThan(expanded, zero) + il.Emit(OpCodes.Call, CachedMethods.V256GreaterThanUShort); + break; + + 
case 1: // byte/bool: load 32 bytes, compare directly + // Vector256.Load(ptr) + il.Emit(OpCodes.Call, CachedMethods.V256LoadByte); + // Vector256.Zero + il.Emit(OpCodes.Call, CachedMethods.V256GetZeroByte); + // Vector256.GreaterThan(vec, zero) + il.Emit(OpCodes.Call, CachedMethods.V256GreaterThanByte); + break; + + default: + throw new NotSupportedException($"Element size {elementSize} not supported"); + } + } + + /// + /// Emit inline V128 mask creation. Stack: byte* -> Vector128{T} (as mask) + /// + private static void EmitInlineMaskCreationV128(ILGenerator il, int elementSize) + { + switch (elementSize) + { + case 8: // double/long: load 2 bytes, expand to 2 qwords + // *(ushort*)ptr + il.Emit(OpCodes.Ldind_U2); + // Vector128.CreateScalar(value) + il.Emit(OpCodes.Call, CachedMethods.V128CreateScalarUShort); + // .AsByte() + il.Emit(OpCodes.Call, CachedMethods.V128UShortAsByte); + // Sse41.ConvertToVector128Int64(bytes) + il.Emit(OpCodes.Call, CachedMethods.Sse41ConvertToV128Int64); + // .AsUInt64() + il.Emit(OpCodes.Call, CachedMethods.V128LongAsULong); + // Vector128.Zero + il.Emit(OpCodes.Call, CachedMethods.V128GetZeroULong); + // Vector128.GreaterThan(expanded, zero) + il.Emit(OpCodes.Call, CachedMethods.V128GreaterThanULong); + break; + + case 4: // float/int: load 4 bytes, expand to 4 dwords + // *(uint*)ptr + il.Emit(OpCodes.Ldind_U4); + // Vector128.CreateScalar(value) + il.Emit(OpCodes.Call, CachedMethods.V128CreateScalarUInt); + // .AsByte() + il.Emit(OpCodes.Call, CachedMethods.V128UIntAsByte); + // Sse41.ConvertToVector128Int32(bytes) + il.Emit(OpCodes.Call, CachedMethods.Sse41ConvertToV128Int32); + // .AsUInt32() + il.Emit(OpCodes.Call, CachedMethods.V128IntAsUInt); + // Vector128.Zero + il.Emit(OpCodes.Call, CachedMethods.V128GetZeroUInt); + // Vector128.GreaterThan(expanded, zero) + il.Emit(OpCodes.Call, CachedMethods.V128GreaterThanUInt); + break; + + case 2: // short/char: load 8 bytes, expand to 8 words + // *(ulong*)ptr + 
il.Emit(OpCodes.Ldind_I8); + // Vector128.CreateScalar(value) + il.Emit(OpCodes.Call, CachedMethods.V128CreateScalarULong); + // .AsByte() + il.Emit(OpCodes.Call, CachedMethods.V128ULongAsByte); + // Sse41.ConvertToVector128Int16(bytes) + il.Emit(OpCodes.Call, CachedMethods.Sse41ConvertToV128Int16); + // .AsUInt16() + il.Emit(OpCodes.Call, CachedMethods.V128ShortAsUShort); + // Vector128.Zero + il.Emit(OpCodes.Call, CachedMethods.V128GetZeroUShort); + // Vector128.GreaterThan(expanded, zero) + il.Emit(OpCodes.Call, CachedMethods.V128GreaterThanUShort); + break; + + case 1: // byte/bool: load 16 bytes, compare directly + // Vector128.Load(ptr) + il.Emit(OpCodes.Call, CachedMethods.V128LoadByte); + // Vector128.Zero + il.Emit(OpCodes.Call, CachedMethods.V128GetZeroByte); + // Vector128.GreaterThan(vec, zero) + il.Emit(OpCodes.Call, CachedMethods.V128GreaterThanByte); + break; + + default: + throw new NotSupportedException($"Element size {elementSize} not supported"); + } + } + + #endregion + + #region Scalar Fallback + + /// + /// Scalar fallback for where operation. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe void WhereScalar(bool* cond, T* x, T* y, T* result, long count) where T : unmanaged + { + for (long i = 0; i < count; i++) + { + result[i] = cond[i] ? x[i] : y[i]; + } + } + + #endregion + + // Per the CachedMethods pattern in ILKernelGenerator.cs, reflection lookups for np.where + // live alongside the other cached entries. Fail-fast at type init so a renamed API shows + // up immediately instead of NREs at first use. + private static partial class CachedMethods + { + #region Where Kernel Methods + + private static MethodInfo FindGenericMethod(Type container, string name, int? 
paramCount = null) + { + foreach (var m in container.GetMethods()) + { + if (m.Name == name && m.IsGenericMethodDefinition && + (paramCount is null || m.GetParameters().Length == paramCount.Value)) + return m; + } + throw new MissingMethodException(container.FullName, name); + } + + private static MethodInfo FindMethodExact(Type container, string name, Type[] argTypes) + => container.GetMethod(name, argTypes) + ?? throw new MissingMethodException(container.FullName, name); + + private static MethodInfo GetZeroGetter(Type vectorOfT) + => vectorOfT.GetProperty("Zero")?.GetMethod + ?? throw new MissingMethodException(vectorOfT.FullName, "get_Zero"); + + // Generic definitions — caller must MakeGenericMethod(typeof(T)) before emitting. + public static readonly MethodInfo V256LoadGeneric = FindGenericMethod(typeof(Vector256), "Load", 1); + public static readonly MethodInfo V256StoreGeneric = FindGenericMethod(typeof(Vector256), "Store", 2); + public static readonly MethodInfo V256ConditionalSelectGeneric = FindGenericMethod(typeof(Vector256), "ConditionalSelect"); + + public static readonly MethodInfo V128LoadGeneric = FindGenericMethod(typeof(Vector128), "Load", 1); + public static readonly MethodInfo V128StoreGeneric = FindGenericMethod(typeof(Vector128), "Store", 2); + public static readonly MethodInfo V128ConditionalSelectGeneric = FindGenericMethod(typeof(Vector128), "ConditionalSelect"); + + // Already-specialised generic methods used during mask creation. 
+ public static readonly MethodInfo V256LoadByte = FindGenericMethod(typeof(Vector256), "Load").MakeGenericMethod(typeof(byte)); + public static readonly MethodInfo V128LoadByte = FindGenericMethod(typeof(Vector128), "Load").MakeGenericMethod(typeof(byte)); + + public static readonly MethodInfo V128CreateScalarUInt = FindGenericMethod(typeof(Vector128), "CreateScalar").MakeGenericMethod(typeof(uint)); + public static readonly MethodInfo V128CreateScalarULong = FindGenericMethod(typeof(Vector128), "CreateScalar").MakeGenericMethod(typeof(ulong)); + public static readonly MethodInfo V128CreateScalarUShort = FindGenericMethod(typeof(Vector128), "CreateScalar").MakeGenericMethod(typeof(ushort)); + + public static readonly MethodInfo V128UIntAsByte = FindGenericMethod(typeof(Vector128), "AsByte").MakeGenericMethod(typeof(uint)); + public static readonly MethodInfo V128ULongAsByte = FindGenericMethod(typeof(Vector128), "AsByte").MakeGenericMethod(typeof(ulong)); + public static readonly MethodInfo V128UShortAsByte = FindGenericMethod(typeof(Vector128), "AsByte").MakeGenericMethod(typeof(ushort)); + + public static readonly MethodInfo V256LongAsULong = FindGenericMethod(typeof(Vector256), "AsUInt64").MakeGenericMethod(typeof(long)); + public static readonly MethodInfo V256IntAsUInt = FindGenericMethod(typeof(Vector256), "AsUInt32").MakeGenericMethod(typeof(int)); + public static readonly MethodInfo V256ShortAsUShort = FindGenericMethod(typeof(Vector256), "AsUInt16").MakeGenericMethod(typeof(short)); + + public static readonly MethodInfo V128LongAsULong = FindGenericMethod(typeof(Vector128), "AsUInt64").MakeGenericMethod(typeof(long)); + public static readonly MethodInfo V128IntAsUInt = FindGenericMethod(typeof(Vector128), "AsUInt32").MakeGenericMethod(typeof(int)); + public static readonly MethodInfo V128ShortAsUShort = FindGenericMethod(typeof(Vector128), "AsUInt16").MakeGenericMethod(typeof(short)); + + public static readonly MethodInfo V256GreaterThanULong = 
FindGenericMethod(typeof(Vector256), "GreaterThan").MakeGenericMethod(typeof(ulong)); + public static readonly MethodInfo V256GreaterThanUInt = FindGenericMethod(typeof(Vector256), "GreaterThan").MakeGenericMethod(typeof(uint)); + public static readonly MethodInfo V256GreaterThanUShort = FindGenericMethod(typeof(Vector256), "GreaterThan").MakeGenericMethod(typeof(ushort)); + public static readonly MethodInfo V256GreaterThanByte = FindGenericMethod(typeof(Vector256), "GreaterThan").MakeGenericMethod(typeof(byte)); + + public static readonly MethodInfo V128GreaterThanULong = FindGenericMethod(typeof(Vector128), "GreaterThan").MakeGenericMethod(typeof(ulong)); + public static readonly MethodInfo V128GreaterThanUInt = FindGenericMethod(typeof(Vector128), "GreaterThan").MakeGenericMethod(typeof(uint)); + public static readonly MethodInfo V128GreaterThanUShort = FindGenericMethod(typeof(Vector128), "GreaterThan").MakeGenericMethod(typeof(ushort)); + public static readonly MethodInfo V128GreaterThanByte = FindGenericMethod(typeof(Vector128), "GreaterThan").MakeGenericMethod(typeof(byte)); + + // Non-generic exact overloads on Avx2/Sse41 for byte-lane sign-extend expansion. 
+ public static readonly MethodInfo Avx2ConvertToV256Int64 = FindMethodExact(typeof(Avx2), "ConvertToVector256Int64", new[] { typeof(Vector128) }); + public static readonly MethodInfo Avx2ConvertToV256Int32 = FindMethodExact(typeof(Avx2), "ConvertToVector256Int32", new[] { typeof(Vector128) }); + public static readonly MethodInfo Avx2ConvertToV256Int16 = FindMethodExact(typeof(Avx2), "ConvertToVector256Int16", new[] { typeof(Vector128) }); + public static readonly MethodInfo Sse41ConvertToV128Int64 = FindMethodExact(typeof(Sse41), "ConvertToVector128Int64", new[] { typeof(Vector128) }); + public static readonly MethodInfo Sse41ConvertToV128Int32 = FindMethodExact(typeof(Sse41), "ConvertToVector128Int32", new[] { typeof(Vector128) }); + public static readonly MethodInfo Sse41ConvertToV128Int16 = FindMethodExact(typeof(Sse41), "ConvertToVector128Int16", new[] { typeof(Vector128) }); + + // Vector*.Zero property getters — emitted as a call, not a field load, so we cache the getter MethodInfo. 
+ public static readonly MethodInfo V256GetZeroULong = GetZeroGetter(typeof(Vector256)); + public static readonly MethodInfo V256GetZeroUInt = GetZeroGetter(typeof(Vector256)); + public static readonly MethodInfo V256GetZeroUShort = GetZeroGetter(typeof(Vector256)); + public static readonly MethodInfo V256GetZeroByte = GetZeroGetter(typeof(Vector256)); + public static readonly MethodInfo V128GetZeroULong = GetZeroGetter(typeof(Vector128)); + public static readonly MethodInfo V128GetZeroUInt = GetZeroGetter(typeof(Vector128)); + public static readonly MethodInfo V128GetZeroUShort = GetZeroGetter(typeof(Vector128)); + public static readonly MethodInfo V128GetZeroByte = GetZeroGetter(typeof(Vector128)); + + #endregion + } + } +} diff --git a/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.cs b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.cs index 37536cf0..134ae6a0 100644 --- a/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.cs +++ b/src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.cs @@ -290,7 +290,7 @@ public static partial class ILKernelGenerator /// Caching these avoids repeated GetMethod() lookups during kernel generation. /// All fields use ?? throw to fail fast at type load if a method is not found. 
/// - private static class CachedMethods + private static partial class CachedMethods { // Math methods (double versions) public static readonly MethodInfo MathPow = typeof(Math).GetMethod(nameof(Math.Pow), new[] { typeof(double), typeof(double) }) diff --git a/src/NumSharp.Core/Creation/np.asanyarray.cs b/src/NumSharp.Core/Creation/np.asanyarray.cs index 5e83dc00..e575250c 100644 --- a/src/NumSharp.Core/Creation/np.asanyarray.cs +++ b/src/NumSharp.Core/Creation/np.asanyarray.cs @@ -1,4 +1,9 @@ using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; namespace NumSharp { @@ -18,29 +23,341 @@ public static NDArray asanyarray(in object a, Type dtype = null) //todo support case null: throw new ArgumentNullException(nameof(a)); case NDArray nd: - return nd; + if (dtype == null || Equals(nd.dtype, dtype)) + return nd; + return nd.astype(dtype, true); + case object[] objArr: + // object[] has no fixed dtype — route through type-promotion path. + // new NDArray(object[]) throws NotSupportedException since object isn't a + // supported element type. 
+ ret = ConvertNonGenericEnumerable(objArr); + if (ret is null) + throw new NotSupportedException($"Unable to resolve asanyarray for object[] (length {objArr.Length}): element type is not a supported NumSharp dtype."); + break; case Array array: ret = new NDArray(array); break; case string str: ret = str; //implicit cast located in NDArray.Implicit.Array break; + + case IEnumerable e: ret = np.array(ToArrayFast(e)); break; + case IEnumerable e: ret = np.array(ToArrayFast(e)); break; + case IEnumerable e: ret = np.array(ToArrayFast(e)); break; + case IEnumerable e: ret = np.array(ToArrayFast(e)); break; + case IEnumerable e: ret = np.array(ToArrayFast(e)); break; + case IEnumerable e: ret = np.array(ToArrayFast(e)); break; + case IEnumerable e: ret = np.array(ToArrayFast(e)); break; + case IEnumerable e: ret = np.array(ToArrayFast(e)); break; + case IEnumerable e: ret = np.array(ToArrayFast(e)); break; + case IEnumerable e: ret = np.array(ToArrayFast(e)); break; + case IEnumerable e: ret = np.array(ToArrayFast(e)); break; + case IEnumerable e: ret = np.array(ToArrayFast(e)); break; + default: var type = a.GetType(); - //is it a scalar if (type.IsPrimitive || type == typeof(decimal)) { ret = NDArray.Scalar(a); break; } - throw new NotSupportedException($"Unable resolve asanyarray for type {a.GetType().Name}"); + // Memory/ReadOnlyMemory do not implement IEnumerable. 
+ if (type.IsGenericType) + { + var genericDef = type.GetGenericTypeDefinition(); + if (genericDef == typeof(Memory<>) || genericDef == typeof(ReadOnlyMemory<>)) + { + ret = ConvertMemory(a, type); + if (ret is not null) + break; + } + } + + if (a is ITuple tuple) + { + ret = ConvertTuple(tuple); + if (ret is not null) + break; + } + + if (a is IEnumerable enumerable) + { + ret = ConvertNonGenericEnumerable(enumerable); + if (ret is not null) + break; + } + + if (a is IEnumerator enumerator) + { + ret = ConvertEnumerator(enumerator); + if (ret is not null) + break; + } + + throw new NotSupportedException($"Unable to resolve asanyarray for type {type.Name}"); } - if (dtype != null && a.GetType() != dtype) + if (dtype != null && !Equals(ret.dtype, dtype)) return ret.astype(dtype, true); return ret; } + + /// + /// Copies an into a freshly allocated []. + /// Specialised for List<T> and ICollection<T> to skip the enumerator and to + /// use since we overwrite every slot. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static T[] ToArrayFast(IEnumerable source) + { + if (source is List list) + { + var span = CollectionsMarshal.AsSpan(list); + var arr = GC.AllocateUninitializedArray(span.Length); + span.CopyTo(arr); + return arr; + } + + if (source is ICollection collection) + { + var arr = GC.AllocateUninitializedArray(collection.Count); + collection.CopyTo(arr, 0); + return arr; + } + + return source.ToArray(); + } + + /// + /// Converts Memory<T> or ReadOnlyMemory<T> to an NDArray. + /// Uses Span.CopyTo + GC.AllocateUninitializedArray for optimal performance. + /// + private static NDArray ConvertMemory(object a, Type type) + { + var elementType = type.GetGenericArguments()[0]; + var isReadOnly = type.GetGenericTypeDefinition() == typeof(ReadOnlyMemory<>); + + if (elementType == typeof(bool)) return np.array(SpanToArrayFast(isReadOnly ? 
((ReadOnlyMemory)a).Span : ((Memory)a).Span)); + if (elementType == typeof(byte)) return np.array(SpanToArrayFast(isReadOnly ? ((ReadOnlyMemory)a).Span : ((Memory)a).Span)); + if (elementType == typeof(short)) return np.array(SpanToArrayFast(isReadOnly ? ((ReadOnlyMemory)a).Span : ((Memory)a).Span)); + if (elementType == typeof(ushort)) return np.array(SpanToArrayFast(isReadOnly ? ((ReadOnlyMemory)a).Span : ((Memory)a).Span)); + if (elementType == typeof(int)) return np.array(SpanToArrayFast(isReadOnly ? ((ReadOnlyMemory)a).Span : ((Memory)a).Span)); + if (elementType == typeof(uint)) return np.array(SpanToArrayFast(isReadOnly ? ((ReadOnlyMemory)a).Span : ((Memory)a).Span)); + if (elementType == typeof(long)) return np.array(SpanToArrayFast(isReadOnly ? ((ReadOnlyMemory)a).Span : ((Memory)a).Span)); + if (elementType == typeof(ulong)) return np.array(SpanToArrayFast(isReadOnly ? ((ReadOnlyMemory)a).Span : ((Memory)a).Span)); + if (elementType == typeof(char)) return np.array(SpanToArrayFast(isReadOnly ? ((ReadOnlyMemory)a).Span : ((Memory)a).Span)); + if (elementType == typeof(float)) return np.array(SpanToArrayFast(isReadOnly ? ((ReadOnlyMemory)a).Span : ((Memory)a).Span)); + if (elementType == typeof(double)) return np.array(SpanToArrayFast(isReadOnly ? ((ReadOnlyMemory)a).Span : ((Memory)a).Span)); + if (elementType == typeof(decimal)) return np.array(SpanToArrayFast(isReadOnly ? ((ReadOnlyMemory)a).Span : ((Memory)a).Span)); + + return null; + } + + /// + /// Optimized Span to Array conversion using GC.AllocateUninitializedArray. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static T[] SpanToArrayFast(ReadOnlySpan span) + { + var arr = GC.AllocateUninitializedArray(span.Length); + span.CopyTo(arr); + return arr; + } + + /// + /// Converts a non-generic IEnumerable to an NDArray. + /// Element type is detected from the first item. 
///
/// <remarks>
/// The element type is promoted across all non-null items (see FindCommonNumericType);
/// null items are skipped, so the result can be shorter than the source sequence.
/// The enumerator is created here and therefore released here.
/// </remarks>
private static NDArray ConvertNonGenericEnumerable(IEnumerable enumerable)
{
    // The element count is exposed by the collection, never by its enumerator,
    // so the capacity hint has to be taken before GetEnumerator() is called.
    var capacityHint = enumerable is ICollection collection ? collection.Count : 0;

    var enumerator = enumerable.GetEnumerator();
    try
    {
        return ConvertEnumerator(enumerator, capacityHint);
    }
    finally
    {
        // Same clean-up a foreach statement performs for enumerators that hold resources.
        (enumerator as IDisposable)?.Dispose();
    }
}

/// <summary>
/// Converts a non-generic IEnumerator to an NDArray.
/// Element type is detected from items with NumPy-like type promotion.
/// Empty collections return empty double[] to match NumPy's float64 default.
/// </summary>
/// <param name="enumerator">
/// Enumerator to drain from its current position. It is owned by the caller and is
/// neither reset nor disposed here.
/// </param>
/// <returns>
/// The converted array, or whatever ConvertObjectListToNDArray yields for the promoted
/// element type (callers in this file treat a null result as "unsupported").
/// </returns>
private static NDArray ConvertEnumerator(IEnumerator enumerator)
    => ConvertEnumerator(enumerator, capacityHint: 0);

/// <summary>
/// Drains <paramref name="enumerator"/> into a list of its non-null items and converts
/// that list to an NDArray.
/// </summary>
/// <param name="enumerator">Enumerator to drain; not disposed by this method.</param>
/// <param name="capacityHint">
/// Expected number of items, used only to pre-size the buffer. Values of zero or less
/// mean "unknown".
/// </param>
private static NDArray ConvertEnumerator(IEnumerator enumerator, int capacityHint)
{
    var items = capacityHint > 0
        ? new List<object>(capacityHint)
        : new List<object>();

    while (enumerator.MoveNext())
    {
        // Null entries carry no type information and are dropped.
        var item = enumerator.Current;
        if (item != null)
            items.Add(item);
    }

    if (items.Count == 0)
        return np.array(Array.Empty<double>());

    var elementType = FindCommonNumericType(items);
    return ConvertObjectListToNDArray(items, elementType);
}

/// <summary>
/// Finds the common numeric type for a list of objects (NumPy-like promotion).
/// Uses existing _FindCommonType_Scalar for consistent type promotion.
/// </summary>
/// <param name="items">Items to inspect; every entry must be non-null.</param>
/// <returns>
/// <c>decimal</c> as soon as any item is a decimal; the shared type when every item maps
/// to the same type code; <c>double</c> for an empty list; otherwise the type that
/// _FindCommonType_Scalar reports for the distinct type codes encountered.
/// </returns>
private static Type FindCommonNumericType(List<object> items)
{
    var span = CollectionsMarshal.AsSpan(items);

    Type firstType = null;

    // A code is appended only the first time its bit in seenMask is set, and seenMask has
    // 32 bits, so 32 slots cannot overflow whatever GetTypeCode() returns.
    // NOTE(review): presumably GetTypeCode() can yield codes outside the 12 supported
    // dtypes for unsupported element types, which a 12-slot buffer would not survive — confirm.
    Span<NPTypeCode> typeCodes = stackalloc NPTypeCode[32];
    int uniqueCount = 0;
    uint seenMask = 0;

    for (int i = 0; i < span.Length; i++)
    {
        var t = span[i].GetType();
        firstType ??= t;

        // decimal wins everything in NumPy promotion
        if (t == typeof(decimal))
            return typeof(decimal);

        var code = t.GetTypeCode();

        // The shift count of a uint shift is taken modulo 32, so two codes that differ by a
        // multiple of 32 share a bit and only the first of them is recorded.
        var bit = 1u << (int)code;
        if ((seenMask & bit) == 0)
        {
            seenMask |= bit;
            typeCodes[uniqueCount++] = code;
        }
    }

    // Zero codes (empty list) or one code: nothing to promote. The double fallback only
    // applies to the empty case, where firstType is still null.
    if (uniqueCount <= 1)
        return firstType ?? typeof(double);

    var resultCode = _FindCommonType_Scalar(typeCodes.Slice(0, uniqueCount).ToArray());
    return resultCode.AsType();
}

///
/// Converts a Tuple or ValueTuple to an NDArray via the ITuple interface.
+ /// + private static NDArray ConvertTuple(ITuple tuple) + { + if (tuple.Length == 0) + return np.array(Array.Empty()); + + var items = new List(tuple.Length); + + for (int i = 0; i < tuple.Length; i++) + { + var item = tuple[i]; + if (item != null) + items.Add(item); + } + + if (items.Count == 0) + return np.array(Array.Empty()); + + var elementType = FindCommonNumericType(items); + return ConvertObjectListToNDArray(items, elementType); + } + + /// + /// Converts a list of objects to an NDArray of the specified element type. + /// The pattern is T v ? v : Convert.ToT(item) takes the direct-cast fast path for + /// homogeneous collections while still handling mixed-type promotion via Convert. + /// + private static NDArray ConvertObjectListToNDArray(List items, Type elementType) + { + var span = CollectionsMarshal.AsSpan(items); + + if (elementType == typeof(bool)) + { + var arr = GC.AllocateUninitializedArray(span.Length); + for (int i = 0; i < span.Length; i++) + arr[i] = span[i] is bool v ? v : Convert.ToBoolean(span[i]); + return np.array(arr); + } + if (elementType == typeof(byte)) + { + var arr = GC.AllocateUninitializedArray(span.Length); + for (int i = 0; i < span.Length; i++) + arr[i] = span[i] is byte v ? v : Convert.ToByte(span[i]); + return np.array(arr); + } + if (elementType == typeof(short)) + { + var arr = GC.AllocateUninitializedArray(span.Length); + for (int i = 0; i < span.Length; i++) + arr[i] = span[i] is short v ? v : Convert.ToInt16(span[i]); + return np.array(arr); + } + if (elementType == typeof(ushort)) + { + var arr = GC.AllocateUninitializedArray(span.Length); + for (int i = 0; i < span.Length; i++) + arr[i] = span[i] is ushort v ? v : Convert.ToUInt16(span[i]); + return np.array(arr); + } + if (elementType == typeof(int)) + { + var arr = GC.AllocateUninitializedArray(span.Length); + for (int i = 0; i < span.Length; i++) + arr[i] = span[i] is int v ? 
v : Convert.ToInt32(span[i]); + return np.array(arr); + } + if (elementType == typeof(uint)) + { + var arr = GC.AllocateUninitializedArray(span.Length); + for (int i = 0; i < span.Length; i++) + arr[i] = span[i] is uint v ? v : Convert.ToUInt32(span[i]); + return np.array(arr); + } + if (elementType == typeof(long)) + { + var arr = GC.AllocateUninitializedArray(span.Length); + for (int i = 0; i < span.Length; i++) + arr[i] = span[i] is long v ? v : Convert.ToInt64(span[i]); + return np.array(arr); + } + if (elementType == typeof(ulong)) + { + var arr = GC.AllocateUninitializedArray(span.Length); + for (int i = 0; i < span.Length; i++) + arr[i] = span[i] is ulong v ? v : Convert.ToUInt64(span[i]); + return np.array(arr); + } + if (elementType == typeof(char)) + { + var arr = GC.AllocateUninitializedArray(span.Length); + for (int i = 0; i < span.Length; i++) + arr[i] = span[i] is char v ? v : Convert.ToChar(span[i]); + return np.array(arr); + } + if (elementType == typeof(float)) + { + var arr = GC.AllocateUninitializedArray(span.Length); + for (int i = 0; i < span.Length; i++) + arr[i] = span[i] is float v ? v : Convert.ToSingle(span[i]); + return np.array(arr); + } + if (elementType == typeof(double)) + { + var arr = GC.AllocateUninitializedArray(span.Length); + for (int i = 0; i < span.Length; i++) + arr[i] = span[i] is double v ? v : Convert.ToDouble(span[i]); + return np.array(arr); + } + if (elementType == typeof(decimal)) + { + var arr = GC.AllocateUninitializedArray(span.Length); + for (int i = 0; i < span.Length; i++) + arr[i] = span[i] is decimal v ? 
v : Convert.ToDecimal(span[i]); + return np.array(arr); + } + + return null; // Unsupported element type + } } } diff --git a/test/NumSharp.UnitTest/Backends/Kernels/WhereSimdTests.cs b/test/NumSharp.UnitTest/Backends/Kernels/WhereSimdTests.cs new file mode 100644 index 00000000..efb42918 --- /dev/null +++ b/test/NumSharp.UnitTest/Backends/Kernels/WhereSimdTests.cs @@ -0,0 +1,515 @@ +using System; +using System.Diagnostics; +using NumSharp.Backends.Kernels; +using NumSharp.UnitTest.Utilities; +using Assert = Microsoft.VisualStudio.TestTools.UnitTesting.Assert; + +namespace NumSharp.UnitTest.Backends.Kernels +{ + /// + /// Tests for SIMD-optimized np.where implementation. + /// Verifies correctness of the SIMD path for all supported dtypes. + /// + [TestClass] + public class WhereSimdTests + { + #region SIMD Correctness + + [TestMethod] + public void Where_Simd_Float32_Correctness() + { + var rng = np.random.RandomState(42); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = rng.rand(size).astype(NPTypeCode.Single); + var y = rng.rand(size).astype(NPTypeCode.Single); + + var result = np.where(cond, x, y); + + // Verify correctness manually + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (float)x[i] : (float)y[i]; + Assert.AreEqual(expected, (float)result[i], 1e-6f, $"Mismatch at index {i}"); + } + } + + [TestMethod] + public void Where_Simd_Float64_Correctness() + { + var rng = np.random.RandomState(43); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = rng.rand(size); + var y = rng.rand(size); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? 
(double)x[i] : (double)y[i]; + Assert.AreEqual(expected, (double)result[i], 1e-10, $"Mismatch at index {i}"); + } + } + + [TestMethod] + public void Where_Simd_Int32_Correctness() + { + var rng = np.random.RandomState(44); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = rng.randint(0, 1000, new[] { size }); + var y = rng.randint(0, 1000, new[] { size }); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (int)x[i] : (int)y[i]; + Assert.AreEqual(expected, (int)result[i], $"Mismatch at index {i}"); + } + } + + [TestMethod] + public void Where_Simd_Int64_Correctness() + { + var rng = np.random.RandomState(45); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = np.arange(size).astype(NPTypeCode.Int64); + var y = np.arange(size, size * 2).astype(NPTypeCode.Int64); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (long)x[i] : (long)y[i]; + Assert.AreEqual(expected, (long)result[i], $"Mismatch at index {i}"); + } + } + + [TestMethod] + public void Where_Simd_Byte_Correctness() + { + var rng = np.random.RandomState(46); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = (rng.rand(size) * 255).astype(NPTypeCode.Byte); + var y = (rng.rand(size) * 255).astype(NPTypeCode.Byte); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (byte)x[i] : (byte)y[i]; + Assert.AreEqual(expected, (byte)result[i], $"Mismatch at index {i}"); + } + } + + [TestMethod] + public void Where_Simd_Int16_Correctness() + { + var rng = np.random.RandomState(47); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = np.arange(size).astype(NPTypeCode.Int16); + var y = np.arange(size, size * 2).astype(NPTypeCode.Int16); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? 
(short)x[i] : (short)y[i]; + Assert.AreEqual(expected, (short)result[i], $"Mismatch at index {i}"); + } + } + + [TestMethod] + public void Where_Simd_UInt16_Correctness() + { + var rng = np.random.RandomState(48); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = np.arange(size).astype(NPTypeCode.UInt16); + var y = np.arange(size, size * 2).astype(NPTypeCode.UInt16); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (ushort)x[i] : (ushort)y[i]; + Assert.AreEqual(expected, (ushort)result[i], $"Mismatch at index {i}"); + } + } + + [TestMethod] + public void Where_Simd_UInt32_Correctness() + { + var rng = np.random.RandomState(49); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = np.arange(size).astype(NPTypeCode.UInt32); + var y = np.arange(size, size * 2).astype(NPTypeCode.UInt32); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (uint)x[i] : (uint)y[i]; + Assert.AreEqual(expected, (uint)result[i], $"Mismatch at index {i}"); + } + } + + [TestMethod] + public void Where_Simd_UInt64_Correctness() + { + var rng = np.random.RandomState(50); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = np.arange(size).astype(NPTypeCode.UInt64); + var y = np.arange(size, size * 2).astype(NPTypeCode.UInt64); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (ulong)x[i] : (ulong)y[i]; + Assert.AreEqual(expected, (ulong)result[i], $"Mismatch at index {i}"); + } + } + + [TestMethod] + public void Where_Simd_Boolean_Correctness() + { + var rng = np.random.RandomState(51); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var x = rng.rand(size) > 0.3; + var y = rng.rand(size) > 0.7; + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? 
(bool)x[i] : (bool)y[i]; + Assert.AreEqual(expected, (bool)result[i], $"Mismatch at index {i}"); + } + } + + [TestMethod] + public void Where_Simd_Char_Correctness() + { + var rng = np.random.RandomState(52); + var size = 1000; + var cond = rng.rand(size) > 0.5; + var xData = new char[size]; + var yData = new char[size]; + for (int i = 0; i < size; i++) + { + xData[i] = (char)('A' + (i % 26)); + yData[i] = (char)('a' + (i % 26)); + } + var x = np.array(xData); + var y = np.array(yData); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? (char)x[i] : (char)y[i]; + Assert.AreEqual(expected, (char)result[i], $"Mismatch at index {i}"); + } + } + + #endregion + + #region Path Selection + + [TestMethod] + public void Where_NonContiguous_Works() + { + // Sliced arrays are non-contiguous, should work correctly + var baseArr = np.arange(20); + var cond = (baseArr % 2 == 0)["::2"]; // Sliced: [true, true, true, true, true, true, true, true, true, true] + var x = np.ones(10, NPTypeCode.Int32); + var y = np.zeros(10, NPTypeCode.Int32); + + var result = np.where(cond, x, y); + + Assert.AreEqual(10, result.size); + // All true -> all from x + for (int i = 0; i < 10; i++) + { + Assert.AreEqual(1, (int)result[i]); + } + } + + [TestMethod] + public void Where_Broadcast_Works() + { + // Broadcasted arrays + // cond shape (3,) broadcasts to (3,3): [[T,F,T],[T,F,T],[T,F,T]] + // x shape (3,1) broadcasts to (3,3): [[1,1,1],[2,2,2],[3,3,3]] + // y shape (1,3) broadcasts to (3,3): [[10,20,30],[10,20,30],[10,20,30]] + var cond = np.array(new[] { true, false, true }); + var x = np.array(new int[,] { { 1 }, { 2 }, { 3 } }); + var y = np.array(new int[,] { { 10, 20, 30 } }); + var result = np.where(cond, x, y); + + result.Should().BeShaped(3, 3); + // Verify values: result[i,j] = cond[j] ? 
x[i,0] : y[0,j] + Assert.AreEqual(1, (int)result[0, 0]); // cond[0]=true -> x=1 + Assert.AreEqual(20, (int)result[0, 1]); // cond[1]=false -> y=20 + Assert.AreEqual(1, (int)result[0, 2]); // cond[2]=true -> x=1 + Assert.AreEqual(2, (int)result[1, 0]); // cond[0]=true -> x=2 + Assert.AreEqual(20, (int)result[1, 1]); // cond[1]=false -> y=20 + } + + [TestMethod] + public void Where_Decimal_Works() + { + var cond = np.array(new[] { true, false, true }); + var x = np.array(new decimal[] { 1.1m, 2.2m, 3.3m }); + var y = np.array(new decimal[] { 10.1m, 20.2m, 30.3m }); + + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(decimal), result.dtype); + Assert.AreEqual(1.1m, (decimal)result[0]); + Assert.AreEqual(20.2m, (decimal)result[1]); + Assert.AreEqual(3.3m, (decimal)result[2]); + } + + [TestMethod] + public void Where_NonBoolCondition_Works() + { + // Non-bool condition requires truthiness check + var cond = np.array(new[] { 0, 1, 2, 0 }); // int condition + var result = np.where(cond, 100, -100); + + result.Should().BeOfValues(-100, 100, 100, -100); + } + + #endregion + + #region Edge Cases + + [TestMethod] + public void Where_Simd_SmallArray() + { + // Array smaller than vector width + var cond = np.array(new[] { true, false, true }); + var x = np.array(new[] { 1, 2, 3 }); + var y = np.array(new[] { 10, 20, 30 }); + + var result = np.where(cond, x, y); + + result.Should().BeOfValues(1, 20, 3); + } + + [TestMethod] + public void Where_Simd_VectorAlignedSize() + { + var rng = np.random.RandomState(53); + // Size exactly matches vector width (no scalar tail) + var size = 32; // V256 byte count + var cond = rng.rand(size) > 0.5; + var x = np.ones(size, NPTypeCode.Byte); + var y = np.zeros(size, NPTypeCode.Byte); + + var result = np.where(cond, x, y); + + Assert.AreEqual(size, result.size); + for (int i = 0; i < size; i++) + { + var expected = (bool)cond[i] ? 
(byte)1 : (byte)0; + Assert.AreEqual(expected, (byte)result[i]); + } + } + + [TestMethod] + public void Where_Simd_WithScalarTail() + { + // Size that requires scalar tail processing + var size = 35; // 32 + 3 tail for bytes + var cond = np.ones(size, NPTypeCode.Boolean); + var x = np.full(size, (byte)255); + var y = np.zeros(size, NPTypeCode.Byte); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + Assert.AreEqual((byte)255, (byte)result[i], $"Mismatch at {i}"); + } + } + + [TestMethod] + public void Where_Simd_AllTrue() + { + var size = 100; + var cond = np.ones(size, NPTypeCode.Boolean); + var x = np.arange(size); + var y = np.full(size, -1L); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + Assert.AreEqual((long)i, (long)result[i]); + } + } + + [TestMethod] + public void Where_Simd_AllFalse() + { + var size = 100; + var cond = np.zeros(size, NPTypeCode.Boolean); + var x = np.arange(size); + var y = np.full(size, -1L); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + Assert.AreEqual(-1L, (long)result[i]); + } + } + + [TestMethod] + public void Where_Simd_Alternating() + { + var size = 100; + var condData = new bool[size]; + for (int i = 0; i < size; i++) + condData[i] = i % 2 == 0; + var cond = np.array(condData); + var x = np.ones(size, NPTypeCode.Int32); + var y = np.zeros(size, NPTypeCode.Int32); + + var result = np.where(cond, x, y); + + for (int i = 0; i < size; i++) + { + Assert.AreEqual(i % 2 == 0 ? 
1 : 0, (int)result[i], $"Mismatch at {i}"); + } + } + + [TestMethod] + public void Where_Simd_NaN_Propagates() + { + var cond = np.array(new[] { true, false, true }); + var x = np.array(new[] { double.NaN, 1.0, 2.0 }); + var y = np.array(new[] { 0.0, double.NaN, 0.0 }); + + var result = np.where(cond, x, y); + + Assert.IsTrue(double.IsNaN((double)result[0])); // NaN from x + Assert.IsTrue(double.IsNaN((double)result[1])); // NaN from y + Assert.AreEqual(2.0, (double)result[2], 1e-10); + } + + [TestMethod] + public void Where_Simd_Infinity() + { + var cond = np.array(new[] { true, false, true, false }); + var x = np.array(new[] { double.PositiveInfinity, 0.0, double.NegativeInfinity, 0.0 }); + var y = np.array(new[] { 0.0, double.PositiveInfinity, 0.0, double.NegativeInfinity }); + + var result = np.where(cond, x, y); + + Assert.AreEqual(double.PositiveInfinity, (double)result[0]); + Assert.AreEqual(double.PositiveInfinity, (double)result[1]); + Assert.AreEqual(double.NegativeInfinity, (double)result[2]); + Assert.AreEqual(double.NegativeInfinity, (double)result[3]); + } + + #endregion + + #region Performance Sanity Check + + [TestMethod] + public void Where_Simd_LargeArray_Correctness() + { + var rng = np.random.RandomState(54); + var size = 100_000; + var cond = rng.rand(size) > 0.5; + var x = np.ones(size, NPTypeCode.Double); + var y = np.zeros(size, NPTypeCode.Double); + + var result = np.where(cond, x, y); + + // Spot check + for (int i = 0; i < 100; i++) + { + var expected = (bool)cond[i] ? 1.0 : 0.0; + Assert.AreEqual(expected, (double)result[i], 1e-10); + } + + // Check last few elements (scalar tail) + for (int i = size - 10; i < size; i++) + { + var expected = (bool)cond[i] ? 
1.0 : 0.0; + Assert.AreEqual(expected, (double)result[i], 1e-10); + } + } + + #endregion + + #region 2D/Multi-dimensional + + [TestMethod] + public void Where_Simd_2D_Contiguous() + { + var rng = np.random.RandomState(55); + // 2D contiguous array should use SIMD + var shape = new[] { 100, 100 }; + var cond = rng.rand(shape) > 0.5; + var x = np.ones(shape, NPTypeCode.Int32); + var y = np.zeros(shape, NPTypeCode.Int32); + + var result = np.where(cond, x, y); + + result.Should().BeShaped(100, 100); + + // Spot check + for (int i = 0; i < 10; i++) + { + for (int j = 0; j < 10; j++) + { + var expected = (bool)cond[i, j] ? 1 : 0; + Assert.AreEqual(expected, (int)result[i, j]); + } + } + } + + [TestMethod] + public void Where_Simd_3D_Contiguous() + { + var rng = np.random.RandomState(56); + var shape = new[] { 10, 20, 30 }; + var cond = rng.rand(shape) > 0.5; + var x = np.ones(shape, NPTypeCode.Single); + var y = np.zeros(shape, NPTypeCode.Single); + + var result = np.where(cond, x, y); + + result.Should().BeShaped(10, 20, 30); + + // Spot check + for (int i = 0; i < 5; i++) + { + for (int j = 0; j < 5; j++) + { + for (int k = 0; k < 5; k++) + { + var expected = (bool)cond[i, j, k] ? 1.0f : 0.0f; + Assert.AreEqual(expected, (float)result[i, j, k], 1e-6f); + } + } + } + } + + #endregion + } +} diff --git a/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs b/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs new file mode 100644 index 00000000..09b626fe --- /dev/null +++ b/test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs @@ -0,0 +1,819 @@ +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Collections.ObjectModel; +using System.Linq; +using AwesomeAssertions; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace NumSharp.UnitTest.Creation +{ + /// + /// Tests for np.asanyarray covering all built-in C# collection types. 
+ /// + [TestClass] + public class np_asanyarray_tests + { + #region NDArray passthrough + + [TestMethod] + public void NDArray_ReturnsAsIs() + { + var original = np.array(1, 2, 3, 4, 5); + var result = np.asanyarray(original); + + // Should return the same instance (no copy) + ReferenceEquals(original, result).Should().BeTrue(); + } + + [TestMethod] + public void NDArray_WithDtype_ReturnsConverted() + { + var original = np.array(1, 2, 3, 4, 5); + var result = np.asanyarray(original, typeof(double)); + + result.dtype.Should().Be(typeof(double)); + result.Should().BeShaped(5); + } + + [TestMethod] + public void NDArray_WithSameDtype_ReturnsAsIs() + { + var original = np.array(1, 2, 3, 4, 5); + var result = np.asanyarray(original, typeof(int)); + + // Same dtype, should return same instance + ReferenceEquals(original, result).Should().BeTrue(); + } + + #endregion + + #region Array types + + [TestMethod] + public void Array_1D() + { + var arr = new int[] { 1, 2, 3, 4, 5 }; + var result = np.asanyarray(arr); + + result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); + result.dtype.Should().Be(typeof(int)); + } + + [TestMethod] + public void Array_2D() + { + var arr = new int[,] { { 1, 2, 3 }, { 4, 5, 6 } }; + var result = np.asanyarray(arr); + + result.Should().BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6); + } + + [TestMethod] + public void Array_WithDtype() + { + var arr = new int[] { 1, 2, 3 }; + var result = np.asanyarray(arr, typeof(double)); + + result.dtype.Should().Be(typeof(double)); + result.Should().BeShaped(3); + } + + #endregion + + #region Scalars + + [TestMethod] + public void Scalar_Int() + { + var result = np.asanyarray(42); + + result.Should().BeScalar().And.BeOfValues(42); + result.dtype.Should().Be(typeof(int)); + } + + [TestMethod] + public void Scalar_Double() + { + var result = np.asanyarray(3.14); + + result.Should().BeScalar(); + result.dtype.Should().Be(typeof(double)); + } + + [TestMethod] + public void Scalar_Decimal() + { + var result 
= np.asanyarray(123.456m); + + result.Should().BeScalar(); + result.dtype.Should().Be(typeof(decimal)); + } + + [TestMethod] + public void Scalar_Bool() + { + var result = np.asanyarray(true); + + result.Should().BeScalar().And.BeOfValues(true); + result.dtype.Should().Be(typeof(bool)); + } + + #endregion + + #region List + + [TestMethod] + public void List_Int() + { + var list = new List { 1, 2, 3, 4, 5 }; + var result = np.asanyarray(list); + + result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); + result.dtype.Should().Be(typeof(int)); + } + + [TestMethod] + public void List_Double() + { + var list = new List { 1.1, 2.2, 3.3 }; + var result = np.asanyarray(list); + + result.Should().BeShaped(3); + result.dtype.Should().Be(typeof(double)); + } + + [TestMethod] + public void List_Bool() + { + var list = new List { true, false, true }; + var result = np.asanyarray(list); + + result.Should().BeShaped(3).And.BeOfValues(true, false, true); + result.dtype.Should().Be(typeof(bool)); + } + + [TestMethod] + public void List_Empty() + { + var list = new List(); + var result = np.asanyarray(list); + + result.Should().BeShaped(0); + result.dtype.Should().Be(typeof(int)); + } + + [TestMethod] + public void List_WithDtype() + { + var list = new List { 1, 2, 3 }; + var result = np.asanyarray(list, typeof(float)); + + result.dtype.Should().Be(typeof(float)); + result.Should().BeShaped(3); + } + + #endregion + + #region IList / ICollection / IEnumerable + + [TestMethod] + public void IList_Int() + { + IList list = new List { 1, 2, 3, 4, 5 }; + var result = np.asanyarray(list); + + result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); + } + + [TestMethod] + public void ICollection_Int() + { + ICollection collection = new List { 1, 2, 3, 4, 5 }; + var result = np.asanyarray(collection); + + result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); + } + + [TestMethod] + public void IEnumerable_Int() + { + IEnumerable enumerable = new List { 1, 2, 3, 4, 5 }; + var result 
= np.asanyarray(enumerable); + + result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); + } + + [TestMethod] + public void IEnumerable_FromLinq() + { + var enumerable = Enumerable.Range(1, 5); + var result = np.asanyarray(enumerable); + + result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); + } + + [TestMethod] + public void IEnumerable_FromLinqSelect() + { + var enumerable = new[] { 1, 2, 3 }.Select(x => x * 2); + var result = np.asanyarray(enumerable); + + result.Should().BeShaped(3).And.BeOfValues(2, 4, 6); + } + + #endregion + + #region IReadOnlyList / IReadOnlyCollection + + [TestMethod] + public void IReadOnlyList_Int() + { + IReadOnlyList list = new List { 1, 2, 3, 4, 5 }; + var result = np.asanyarray(list); + + result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); + } + + [TestMethod] + public void IReadOnlyCollection_Int() + { + IReadOnlyCollection collection = new List { 1, 2, 3, 4, 5 }; + var result = np.asanyarray(collection); + + result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); + } + + #endregion + + #region ReadOnlyCollection + + [TestMethod] + public void ReadOnlyCollection_Int() + { + var collection = new ReadOnlyCollection(new List { 1, 2, 3, 4, 5 }); + var result = np.asanyarray(collection); + + result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); + } + + #endregion + + #region LinkedList + + [TestMethod] + public void LinkedList_Int() + { + var linkedList = new LinkedList(); + linkedList.AddLast(1); + linkedList.AddLast(2); + linkedList.AddLast(3); + var result = np.asanyarray(linkedList); + + result.Should().BeShaped(3).And.BeOfValues(1, 2, 3); + } + + #endregion + + #region HashSet / SortedSet + + [TestMethod] + public void HashSet_Int() + { + var set = new HashSet { 3, 1, 4, 1, 5, 9 }; // Duplicates removed + var result = np.asanyarray(set); + + result.size.Should().Be(5); // 1, 3, 4, 5, 9 (no duplicates) + result.dtype.Should().Be(typeof(int)); + } + + [TestMethod] + public void SortedSet_Int() + { + var set = 
new SortedSet { 3, 1, 4, 1, 5, 9 }; + var result = np.asanyarray(set); + + result.Should().BeShaped(5).And.BeOfValues(1, 3, 4, 5, 9); // Sorted, no duplicates + } + + #endregion + + #region Queue / Stack + + [TestMethod] + public void Queue_Int() + { + var queue = new Queue(); + queue.Enqueue(1); + queue.Enqueue(2); + queue.Enqueue(3); + var result = np.asanyarray(queue); + + result.Should().BeShaped(3).And.BeOfValues(1, 2, 3); + } + + [TestMethod] + public void Stack_Int() + { + var stack = new Stack(); + stack.Push(1); + stack.Push(2); + stack.Push(3); + var result = np.asanyarray(stack); + + result.Should().BeShaped(3).And.BeOfValues(3, 2, 1); // LIFO order + } + + #endregion + + #region ArraySegment + + [TestMethod] + public void ArraySegment_Int() + { + var array = new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; + var segment = new ArraySegment(array, 2, 5); // Elements 2,3,4,5,6 + var result = np.asanyarray(segment); + + result.Should().BeShaped(5).And.BeOfValues(2, 3, 4, 5, 6); + } + + [TestMethod] + public void ArraySegment_Empty() + { + var array = new int[] { 1, 2, 3 }; + var segment = new ArraySegment(array, 0, 0); + var result = np.asanyarray(segment); + + result.Should().BeShaped(0); + } + + [TestMethod] + public void ArraySegment_Full() + { + var array = new int[] { 1, 2, 3, 4, 5 }; + var segment = new ArraySegment(array); + var result = np.asanyarray(segment); + + result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); + } + + #endregion + + #region Memory / ReadOnlyMemory + + [TestMethod] + public void Memory_Int() + { + var array = new int[] { 1, 2, 3, 4, 5 }; + var memory = new Memory(array, 1, 3); // Elements 2,3,4 + var result = np.asanyarray(memory); + + result.Should().BeShaped(3).And.BeOfValues(2, 3, 4); + } + + [TestMethod] + public void ReadOnlyMemory_Int() + { + var array = new int[] { 1, 2, 3, 4, 5 }; + var memory = new ReadOnlyMemory(array, 1, 3); // Elements 2,3,4 + var result = np.asanyarray(memory); + + 
result.Should().BeShaped(3).And.BeOfValues(2, 3, 4); + } + + #endregion + + #region ImmutableArray / ImmutableList + + [TestMethod] + public void ImmutableArray_Int() + { + var immutableArray = ImmutableArray.Create(1, 2, 3, 4, 5); + var result = np.asanyarray(immutableArray); + + result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); + } + + [TestMethod] + public void ImmutableList_Int() + { + var immutableList = ImmutableList.Create(1, 2, 3, 4, 5); + var result = np.asanyarray(immutableList); + + result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); + } + + [TestMethod] + public void ImmutableHashSet_Int() + { + var immutableSet = ImmutableHashSet.Create(3, 1, 4, 1, 5); + var result = np.asanyarray(immutableSet); + + result.size.Should().Be(4); // 1, 3, 4, 5 (no duplicates) + } + + #endregion + + #region All supported dtypes via List + + [TestMethod] + public void List_Byte() + { + var list = new List { 1, 2, 3 }; + var result = np.asanyarray(list); + result.dtype.Should().Be(typeof(byte)); + result.Should().BeShaped(3); + } + + // Note: sbyte is NOT supported by NumSharp (12 supported types: bool, byte, short, ushort, int, uint, long, ulong, char, float, double, decimal) + + [TestMethod] + public void List_Short() + { + var list = new List { 1, 2, 3 }; + var result = np.asanyarray(list); + result.dtype.Should().Be(typeof(short)); + result.Should().BeShaped(3); + } + + [TestMethod] + public void List_UShort() + { + var list = new List { 1, 2, 3 }; + var result = np.asanyarray(list); + result.dtype.Should().Be(typeof(ushort)); + result.Should().BeShaped(3); + } + + [TestMethod] + public void List_UInt() + { + var list = new List { 1, 2, 3 }; + var result = np.asanyarray(list); + result.dtype.Should().Be(typeof(uint)); + result.Should().BeShaped(3); + } + + [TestMethod] + public void List_Long() + { + var list = new List { 1, 2, 3 }; + var result = np.asanyarray(list); + result.dtype.Should().Be(typeof(long)); + result.Should().BeShaped(3); + } + + 
[TestMethod] + public void List_ULong() + { + var list = new List { 1, 2, 3 }; + var result = np.asanyarray(list); + result.dtype.Should().Be(typeof(ulong)); + result.Should().BeShaped(3); + } + + [TestMethod] + public void List_Float() + { + var list = new List { 1.1f, 2.2f, 3.3f }; + var result = np.asanyarray(list); + result.dtype.Should().Be(typeof(float)); + result.Should().BeShaped(3); + } + + [TestMethod] + public void List_Char() + { + var list = new List { 'a', 'b', 'c' }; + var result = np.asanyarray(list); + result.dtype.Should().Be(typeof(char)); + result.Should().BeShaped(3); + } + + #endregion + + #region Error cases + + [TestMethod] + public void Null_ThrowsArgumentNullException() + { + Assert.ThrowsException(() => np.asanyarray(null)); + } + + [TestMethod] + public void UnsupportedType_ThrowsNotSupportedException() + { + // String collections are not supported (string is not primitive/decimal) + var stringList = new List { "a", "b", "c" }; + Assert.ThrowsException(() => np.asanyarray(stringList)); + } + + [TestMethod] + public void CustomClass_ThrowsNotSupportedException() + { + var customObject = new object(); + Assert.ThrowsException(() => np.asanyarray(customObject)); + } + + #endregion + + #region String special case + + [TestMethod] + public void String_CreatesCharArray() + { + var result = np.asanyarray("hello"); + + result.Should().BeShaped(5); + result.dtype.Should().Be(typeof(char)); + } + + #endregion + + #region Non-generic IEnumerable fallback + + [TestMethod] + public void ArrayList_Int() + { + var arrayList = new System.Collections.ArrayList { 1, 2, 3, 4, 5 }; + var result = np.asanyarray(arrayList); + + result.Should().BeShaped(5).And.BeOfValues(1, 2, 3, 4, 5); + result.dtype.Should().Be(typeof(int)); + } + + [TestMethod] + public void ArrayList_Double() + { + var arrayList = new System.Collections.ArrayList { 1.1, 2.2, 3.3 }; + var result = np.asanyarray(arrayList); + + result.Should().BeShaped(3); + 
result.dtype.Should().Be(typeof(double)); + } + + [TestMethod] + public void Hashtable_Keys() + { + var hashtable = new System.Collections.Hashtable { { 1, "a" }, { 2, "b" }, { 3, "c" } }; + var result = np.asanyarray(hashtable.Keys); + + result.size.Should().Be(3); + result.dtype.Should().Be(typeof(int)); + } + + #endregion + + #region IEnumerator fallback + + [TestMethod] + public void IEnumerator_Int() + { + static System.Collections.IEnumerator GetEnumerator() + { + yield return 10; + yield return 20; + yield return 30; + } + + var result = np.asanyarray(GetEnumerator()); + + result.Should().BeShaped(3).And.BeOfValues(10, 20, 30); + result.dtype.Should().Be(typeof(int)); + } + + #endregion + + #region NumPy Parity - Misaligned Behaviors + + /// + /// NumPy treats strings as scalar Unicode values, NumSharp treats as char arrays. + /// NumPy: np.asanyarray("hello") -> dtype=<U5, shape=(), ndim=0 (SCALAR) + /// NumSharp: dtype=Char, shape=(5), ndim=1 (ARRAY) + /// + [TestMethod] + [Misaligned] + public void String_IsCharArray_NotScalar() + { + var result = np.asanyarray("hello"); + + // NumSharp behavior: char array + result.ndim.Should().Be(1); + result.shape.Should().BeEquivalentTo(new[] { 5 }); + result.dtype.Should().Be(typeof(char)); + + // NumPy would be: ndim=0, shape=(), dtype=<U5 + } + + /// + /// NumPy stores sets as object scalars (not iterated). + /// NumSharp iterates sets and converts to array. 
+ /// NumPy: np.asanyarray({1,2,3}) -> dtype=object, shape=() (SCALAR) + /// NumSharp: dtype=Int32, shape=(3) (ARRAY) + /// + [TestMethod] + [Misaligned] + public void HashSet_IsIterated_NotObjectScalar() + { + var set = new HashSet { 1, 2, 3 }; + var result = np.asanyarray(set); + + // NumSharp behavior: iterates and creates array + result.ndim.Should().Be(1); + result.size.Should().Be(3); + result.dtype.Should().Be(typeof(int)); + + // NumPy would be: dtype=object, shape=() (object scalar containing set) + } + + /// + /// NumPy stores generators as object scalars (NOT consumed). + /// NumSharp consumes IEnumerable and converts to array. + /// This is arguably more useful behavior for C#. + /// + [TestMethod] + [Misaligned] + public void LinqEnumerable_IsConsumed_NotObjectScalar() + { + var enumerable = new[] { 1, 2, 3 }.Select(x => x * 2); + var result = np.asanyarray(enumerable); + + // NumSharp behavior: consumes and creates array + result.ndim.Should().Be(1); + result.Should().BeShaped(3).And.BeOfValues(2, 4, 6); + + // NumPy generator would be: dtype=object, shape=() (NOT consumed) + } + + /// + /// For typed empty collections (List<T>), NumSharp preserves the generic type parameter. + /// NumPy defaults to float64 for untyped empty lists. + /// This is a design choice: C# generics provide type information that NumPy doesn't have. + /// + [TestMethod] + [Misaligned] + public void EmptyTypedList_PreservesTypeParameter() + { + var result = np.asanyarray(new List()); + + // NumSharp behavior: preserves int dtype from generic type parameter + result.dtype.Should().Be(typeof(int)); + result.shape.Should().BeEquivalentTo(new[] { 0 }); + + // NumPy would be: dtype=float64, shape=(0,) + // NumSharp can do better because C# generics provide the type at compile time + } + + #endregion + + #region Tuple support + + /// + /// C# ValueTuples are iterable like Python tuples. 
+ /// NumPy: np.asanyarray((1,2,3)) -> dtype=int64, shape=(3,) + /// + [TestMethod] + public void ValueTuple_IsIterable() + { + var tuple = (1, 2, 3); + var result = np.asanyarray(tuple); + + result.Should().BeShaped(3).And.BeOfValues(1, 2, 3); + result.dtype.Should().Be(typeof(int)); + } + + /// + /// C# Tuple class is iterable like Python tuples. + /// + [TestMethod] + public void Tuple_IsIterable() + { + var tuple = Tuple.Create(1, 2, 3); + var result = np.asanyarray(tuple); + + result.Should().BeShaped(3).And.BeOfValues(1, 2, 3); + result.dtype.Should().Be(typeof(int)); + } + + [TestMethod] + public void ValueTuple_MixedTypes_PromotesToCommonType() + { + // Mixed int + double promotes to double (NumPy behavior) + var tuple = (1, 2.5, 3); + var result = np.asanyarray(tuple); + + result.Should().BeShaped(3); + result.dtype.Should().Be(typeof(double)); // Promoted from int to double + } + + [TestMethod] + public void ValueTuple_IntAndBool_PromotesToInt() + { + // Mixed int + bool promotes to int (NumPy behavior) + var tuple = (1, true, 3); + var result = np.asanyarray(tuple); + + result.Should().BeShaped(3); + result.dtype.Should().Be(typeof(int)); + } + + [TestMethod] + public void EmptyTuple_ReturnsEmptyDoubleArray() + { + var tuple = ValueTuple.Create(); + var result = np.asanyarray(tuple); + + result.Should().BeShaped(0); + result.dtype.Should().Be(typeof(double)); + } + + #endregion + + #region Empty non-generic collections + + /// + /// Empty non-generic collections return empty double[] (NumPy defaults to float64). 
+ /// + [TestMethod] + public void EmptyArrayList_ReturnsEmptyDoubleArray() + { + var arrayList = new System.Collections.ArrayList(); + var result = np.asanyarray(arrayList); + + result.size.Should().Be(0); + result.ndim.Should().Be(1); + result.dtype.Should().Be(typeof(double)); // NumPy: float64 + } + + #endregion + + #region object[] regression + + [TestMethod] + public void ObjectArray_Homogeneous_Int() + { + var arr = new object[] { 1, 2, 3 }; + var result = np.asanyarray(arr); + + result.dtype.Should().Be(typeof(int)); + result.Should().BeShaped(3).And.BeOfValues(1, 2, 3); + } + + [TestMethod] + public void ObjectArray_MixedIntFloat_PromotesToDouble() + { + var arr = new object[] { 1, 2.5, 3 }; + var result = np.asanyarray(arr); + + result.dtype.Should().Be(typeof(double)); + result.Should().BeShaped(3).And.BeOfValues(1.0, 2.5, 3.0); + } + + [TestMethod] + public void ObjectArray_MixedBoolInt_PromotesToInt() + { + var arr = new object[] { true, 2, false }; + var result = np.asanyarray(arr); + + result.dtype.Should().Be(typeof(int)); + result.Should().BeShaped(3).And.BeOfValues(1, 2, 0); + } + + [TestMethod] + public void ObjectArray_Empty_ReturnsFloat64() + { + var arr = new object[0]; + var result = np.asanyarray(arr); + + result.size.Should().Be(0); + result.ndim.Should().Be(1); + result.dtype.Should().Be(typeof(double)); + } + + [TestMethod] + public void ObjectArray_AllFloat_PreservesSingle() + { + // Regression: an earlier FindCommonNumericType short-circuit promoted any float + // to double. NumPy preserves float32 for homogeneous float32 inputs. + var arr = new object[] { 1.5f, 2.5f, 3.5f }; + var result = np.asanyarray(arr); + + result.dtype.Should().Be(typeof(float)); + result.Should().BeShaped(3).And.BeOfValues(1.5f, 2.5f, 3.5f); + } + + [TestMethod] + public void ObjectArray_MixedIntAndFloat32_PromotesToDouble() + { + // int + float32 -> float64 per NumPy NEP50. 
+ var arr = new object[] { 1, 2.5f, 3 }; + var result = np.asanyarray(arr); + + result.dtype.Should().Be(typeof(double)); + result.Should().BeShaped(3); + } + + #endregion + } +} diff --git a/test/NumSharp.UnitTest/Logic/np.where.BattleTest.cs b/test/NumSharp.UnitTest/Logic/np.where.BattleTest.cs new file mode 100644 index 00000000..eb889b7d --- /dev/null +++ b/test/NumSharp.UnitTest/Logic/np.where.BattleTest.cs @@ -0,0 +1,759 @@ +using System; +using System.Linq; +using NumSharp.UnitTest.Utilities; +using Assert = Microsoft.VisualStudio.TestTools.UnitTesting.Assert; + +namespace NumSharp.UnitTest.Logic +{ + /// + /// Battle tests for np.where - edge cases, strided arrays, views, etc. + /// + /// These tests verify NumSharp behavior against NumPy 2.4.2. + /// + /// KNOWN DIFFERENCES FROM NUMPY 2.x: + /// + /// 1. Scalar Type Promotion (NEP50): + /// NumPy 2.x uses "weak scalar" semantics where Python scalars adopt array dtype. + /// NumSharp applies the same rule to scalar + array operands (no longer a difference). + /// + /// Example: np.where(cond, 1, uint8_array) + /// - NumPy 2.x: returns uint8 (weak scalar rule) + /// - NumSharp: returns uint8 (asserted by Where_ScalarTypePromotion_NEP50_WeakScalar) + /// + /// 2. Python int Scalar Default: + /// - NumPy: Python int → int64 (platform default) + /// - NumSharp: C# int literal → int32 + /// + /// 3. Missing sbyte (int8) support: + /// NumSharp does not support sbyte arrays (throws NotSupportedException). 
+ /// + [TestClass] + public class np_where_BattleTest + { + #region Strided/Sliced Arrays + + [TestMethod] + public void Where_SlicedCondition() + { + // Sliced condition array (non-contiguous) + var arr = np.arange(10); + var cond = (arr % 2 == 0)["::2"]; // Every other even check: [T,T,T,T,T] + var x = np.ones(5, NPTypeCode.Int32); + var y = np.zeros(5, NPTypeCode.Int32); + var result = np.where(cond, x, y); + + Assert.AreEqual(5, result.size); + result.Should().BeOfValues(1, 1, 1, 1, 1); + } + + [TestMethod] + public void Where_SlicedXY() + { + var cond = np.array(new[] { true, false, true }); + var x = np.arange(6)["::2"]; // [0, 2, 4] + var y = np.arange(6)["1::2"]; // [1, 3, 5] + var result = np.where(cond, x, y); + + result.Should().BeOfValues(0L, 3L, 4L); + } + + [TestMethod] + public void Where_TransposedArrays() + { + var cond = np.array(new bool[,] { { true, false }, { false, true } }).T; + var x = np.array(new int[,] { { 1, 2 }, { 3, 4 } }).T; + var y = np.array(new int[,] { { 10, 20 }, { 30, 40 } }).T; + var result = np.where(cond, x, y); + + result.Should().BeShaped(2, 2); + // After transpose: cond[0,0]=T, cond[0,1]=F, cond[1,0]=F, cond[1,1]=T + Assert.AreEqual(1, (int)result[0, 0]); + Assert.AreEqual(30, (int)result[0, 1]); + Assert.AreEqual(20, (int)result[1, 0]); + Assert.AreEqual(4, (int)result[1, 1]); + } + + [TestMethod] + public void Where_ReversedSlice() + { + var cond = np.array(new[] { true, false, true, false, true }); + var x = np.arange(5)["::-1"]; // [4, 3, 2, 1, 0] + var y = np.zeros(5, NPTypeCode.Int64); + var result = np.where(cond, x, y); + + // NumPy: [4, 0, 2, 0, 0] + result.Should().BeOfValues(4L, 0L, 2L, 0L, 0L); + } + + #endregion + + #region Complex Broadcasting + + [TestMethod] + public void Where_3Way_Broadcasting() + { + // cond: (2,1,1), x: (1,3,1), y: (1,1,4) -> result: (2,3,4) + var cond = np.array(new bool[,,] { { { true } }, { { false } } }); + var x = np.arange(3).reshape(1, 3, 1); + var y = (np.arange(4) * 
10).reshape(1, 1, 4); + var result = np.where(cond, x, y); + + result.Should().BeShaped(2, 3, 4); + // First "page" (cond=True): values from x broadcast + Assert.AreEqual(0, (long)result[0, 0, 0]); + Assert.AreEqual(0, (long)result[0, 0, 3]); + Assert.AreEqual(2, (long)result[0, 2, 0]); + // Second "page" (cond=False): values from y broadcast + Assert.AreEqual(0, (long)result[1, 0, 0]); + Assert.AreEqual(30, (long)result[1, 0, 3]); + Assert.AreEqual(30, (long)result[1, 2, 3]); + } + + [TestMethod] + public void Where_RowVector_ColVector_Broadcast() + { + // cond: (1,4), x: (3,1), y: scalar -> result: (3,4) + var cond = np.array(new bool[,] { { true, false, true, false } }); + var x = np.array(new int[,] { { 1 }, { 2 }, { 3 } }); + var result = np.where(cond, x, 0); + + result.Should().BeShaped(3, 4); + Assert.AreEqual(1, (int)result[0, 0]); + Assert.AreEqual(0, (int)result[0, 1]); + Assert.AreEqual(2, (int)result[1, 0]); + Assert.AreEqual(0, (int)result[1, 1]); + } + + [TestMethod] + public void Where_ScalarCondition_True() + { + // NumPy: np.where(True, [1,2,3], [4,5,6]) -> [1,2,3] + var result = np.where(np.array(true), np.array(new[] { 1, 2, 3 }), np.array(new[] { 4, 5, 6 })); + result.Should().BeOfValues(1, 2, 3); + } + + [TestMethod] + public void Where_ScalarCondition_False() + { + // NumPy: np.where(False, [1,2,3], [4,5,6]) -> [4,5,6] + var result = np.where(np.array(false), np.array(new[] { 1, 2, 3 }), np.array(new[] { 4, 5, 6 })); + result.Should().BeOfValues(4, 5, 6); + } + + #endregion + + #region Non-Boolean Conditions (Truthy/Falsy) + + [TestMethod] + public void Where_IntegerCondition_ZeroIsFalsy() + { + // NumPy: 0 is falsy, non-zero is truthy + var cond = np.array(new[] { 0, 1, 2, -1, 0 }); + var x = np.ones(5); + var y = np.zeros(5); + var result = np.where(cond, x, y); + + // NumPy: [0, 1, 1, 1, 0] + result.Should().BeOfValues(0.0, 1.0, 1.0, 1.0, 0.0); + } + + [TestMethod] + public void Where_FloatCondition_ZeroIsFalsy() + { + // NumPy: 0.0 is 
falsy + var cond = np.array(new[] { 0.0, 0.5, 1.0, -0.1, 0.0 }); + var x = np.ones(5); + var y = np.zeros(5); + var result = np.where(cond, x, y); + + // NumPy: [0, 1, 1, 1, 0] + result.Should().BeOfValues(0.0, 1.0, 1.0, 1.0, 0.0); + } + + [TestMethod] + public void Where_NaN_IsTruthy() + { + // NumPy: NaN is truthy (non-zero) + var cond = np.array(new[] { 0.0, double.NaN, 1.0 }); + var x = np.array(new[] { 1, 2, 3 }); + var y = np.array(new[] { 10, 20, 30 }); + var result = np.where(cond, x, y); + + // NumPy: [10, 2, 3] (NaN is truthy) + result.Should().BeOfValues(10, 2, 3); + } + + [TestMethod] + public void Where_Infinity_IsTruthy() + { + // NumPy: Inf and -Inf are truthy + var cond = np.array(new[] { 0.0, double.PositiveInfinity, double.NegativeInfinity }); + var x = np.array(new[] { 1, 2, 3 }); + var y = np.array(new[] { 10, 20, 30 }); + var result = np.where(cond, x, y); + + // NumPy: [10, 2, 3] + result.Should().BeOfValues(10, 2, 3); + } + + [TestMethod] + public void Where_NegativeZero_IsFalsy() + { + // NumPy: -0.0 == 0.0 in IEEE 754, so it's falsy + var cond = np.array(new[] { 0.0, -0.0, 1.0 }); + var x = np.array(new[] { 1, 2, 3 }); + var y = np.array(new[] { 10, 20, 30 }); + var result = np.where(cond, x, y); + + // NumPy: [10, 20, 3] (both 0.0 and -0.0 are falsy) + result.Should().BeOfValues(10, 20, 3); + } + + #endregion + + #region Numeric Edge Cases + + [TestMethod] + public void Where_NaN_Values() + { + var cond = np.array(new[] { true, false, true }); + var x = np.array(new[] { double.NaN, 1.0, double.NaN }); + var y = np.array(new[] { 0.0, double.NaN, 0.0 }); + var result = np.where(cond, x, y); + + Assert.IsTrue(double.IsNaN((double)result[0])); // from x + Assert.IsTrue(double.IsNaN((double)result[1])); // from y + Assert.IsTrue(double.IsNaN((double)result[2])); // from x + } + + [TestMethod] + public void Where_Infinity_Values() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new[] { double.PositiveInfinity, 1.0 }); + var 
y = np.array(new[] { 0.0, double.NegativeInfinity }); + var result = np.where(cond, x, y); + + Assert.AreEqual(double.PositiveInfinity, (double)result[0]); + Assert.AreEqual(double.NegativeInfinity, (double)result[1]); + } + + [TestMethod] + public void Where_MaxMin_Values() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new[] { long.MaxValue, 0L }); + var y = np.array(new[] { 0L, long.MinValue }); + var result = np.where(cond, x, y); + + Assert.AreEqual(long.MaxValue, (long)result[0]); + Assert.AreEqual(long.MinValue, (long)result[1]); + } + + #endregion + + #region Single Arg Edge Cases + + [TestMethod] + public void Where_SingleArg_Float_Truthy() + { + // 0.0 and -0.0 are falsy; any other value (including NaN, Inf) is truthy + // Note: -0.0 == 0.0 in IEEE 754, so it's falsy + var arr = np.array(new[] { 0.0, 1.0, -1.0, 0.5, -0.0 }); + var result = np.where(arr); + + // NumPy: indices [1, 2, 3] (-0.0 is falsy) + result[0].Should().BeOfValues(1L, 2L, 3L); + } + + [TestMethod] + public void Where_SingleArg_NaN_IsTruthy() + { + // NaN is non-zero, so it's truthy + var arr = np.array(new[] { 0.0, double.NaN, 0.0 }); + var result = np.where(arr); + + result[0].Should().BeOfValues(1L); + } + + [TestMethod] + public void Where_SingleArg_Infinity_IsTruthy() + { + // Inf values are truthy + var arr = np.array(new[] { 0.0, double.PositiveInfinity, double.NegativeInfinity, 0.0 }); + var result = np.where(arr); + + result[0].Should().BeOfValues(1L, 2L); + } + + [TestMethod] + public void Where_SingleArg_4D() + { + var arr = np.zeros(new[] { 2, 2, 2, 2 }, NPTypeCode.Int32); + arr[0, 1, 0, 1] = 1; + arr[1, 0, 1, 0] = 1; + var result = np.where(arr); + + Assert.AreEqual(4, result.Length); // 4 dimensions + Assert.AreEqual(2, result[0].size); // 2 non-zero elements + } + + [TestMethod] + public void Where_SingleArg_ReturnsInt64Indices() + { + // NumPy returns int64 for indices + var arr = np.array(new[] { 0, 1, 0, 2 }); + var result = np.where(arr); + + 
Assert.AreEqual(typeof(long), result[0].dtype); + } + + #endregion + + #region 0D Scalar Arrays + + [TestMethod] + public void Where_0D_AllScalars_Returns0D() + { + // NumPy: when all inputs are 0D, result is 0D + var cond = np.array(true).reshape(); // 0D + var x = np.array(42).reshape(); // 0D + var y = np.array(99).reshape(); // 0D + var result = np.where(cond, x, y); + + Assert.AreEqual(0, result.ndim); + Assert.AreEqual(42, (int)result.GetValue(0)); + } + + [TestMethod] + public void Where_0D_Cond_With_1D_Arrays() + { + // 0D condition broadcasts to match x/y shape + var cond = np.array(true).reshape(); // 0D + var x = np.array(new[] { 1, 2, 3 }); + var y = np.array(new[] { 10, 20, 30 }); + var result = np.where(cond, x, y); + + result.Should().BeShaped(3); + result.Should().BeOfValues(1, 2, 3); + } + + #endregion + + #region Type Promotion (Array-to-Array) + + [TestMethod] + public void Where_TypePromotion_Bool_Int16() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new bool[] { true, false }); + var y = np.array(new short[] { 10, 20 }); + var result = np.where(cond, x, y); + + // NumPy: int16 + Assert.AreEqual(typeof(short), result.dtype); + } + + [TestMethod] + public void Where_TwoScalars_Byte_StaysByte() + { + // C# byte (like np.uint8) stays byte, not widened to int64 + var cond = np.array(new[] { true, false }); + var result = np.where(cond, (byte)1, (byte)0); + + Assert.AreEqual(typeof(byte), result.dtype); + Assert.AreEqual((byte)1, (byte)result[0]); + Assert.AreEqual((byte)0, (byte)result[1]); + } + + [TestMethod] + public void Where_TwoScalars_Short_StaysShort() + { + // C# short (like np.int16) stays short + var cond = np.array(new[] { true, false }); + var result = np.where(cond, (short)100, (short)200); + + Assert.AreEqual(typeof(short), result.dtype); + } + + [TestMethod] + public void Where_TypePromotion_Int32_UInt32() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new int[] { 1, 2 }); + var y = 
np.array(new uint[] { 10, 20 }); + var result = np.where(cond, x, y); + + // NumPy: int64 (to accommodate both signed and unsigned 32-bit range) + Assert.AreEqual(typeof(long), result.dtype); + } + + [TestMethod] + public void Where_TypePromotion_Int64_UInt64() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new long[] { 1, 2 }); + var y = np.array(new ulong[] { 10, 20 }); + var result = np.where(cond, x, y); + + // NumPy: float64 (no integer type can hold both int64 and uint64 full range) + Assert.AreEqual(typeof(double), result.dtype); + } + + [TestMethod] + public void Where_TypePromotion_UInt8_Float32() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new byte[] { 1, 2 }); + var y = np.array(new float[] { 10.5f, 20.5f }); + var result = np.where(cond, x, y); + + // NumPy: float32 + Assert.AreEqual(typeof(float), result.dtype); + } + + #endregion + + #region Performance/Stress Tests + + [TestMethod] + public void Where_LargeArray_Performance() + { + var size = 1_000_000; + var cond = np.random.rand(size) > 0.5; + var x = np.ones(size, NPTypeCode.Double); + var y = np.zeros(size, NPTypeCode.Double); + + var sw = System.Diagnostics.Stopwatch.StartNew(); + var result = np.where(cond, x, y); + sw.Stop(); + + Assert.AreEqual(size, result.size); + // Should complete in reasonable time (< 1 second for 1M elements) + Assert.IsTrue(sw.ElapsedMilliseconds < 1000, $"Took {sw.ElapsedMilliseconds}ms"); + } + + [TestMethod] + public void Where_ManyDimensions() + { + // 6D array + var shape = new[] { 2, 3, 2, 2, 2, 3 }; + var cond = np.ones(shape, NPTypeCode.Boolean); + var x = np.ones(shape, NPTypeCode.Int32); + var y = np.zeros(shape, NPTypeCode.Int32); + var result = np.where(cond, x, y); + + result.Should().BeShaped(2, 3, 2, 2, 2, 3); + Assert.AreEqual(144, result.size); + Assert.AreEqual(144, (long)np.sum(result)); // All 1s + } + + [TestMethod] + public void Where_AllTrue_LargeArray() + { + var size = 10000; + var cond = 
np.ones(size, NPTypeCode.Boolean); + var x = np.arange(size); + var y = np.zeros(size, NPTypeCode.Int64); + var result = np.where(cond, x, y); + + // Sum of 0 to 9999 = 49995000 + Assert.AreEqual(49995000L, (long)np.sum(result)); + } + + [TestMethod] + public void Where_AllFalse_LargeArray() + { + var size = 10000; + var cond = np.zeros(size, NPTypeCode.Boolean); + var x = np.arange(size); + var y = np.zeros(size, NPTypeCode.Int64); + var result = np.where(cond, x, y); + + Assert.AreEqual(0L, (long)np.sum(result)); + } + + [TestMethod] + public void Where_Alternating_LargeArray() + { + var size = 10000; + var cond = np.zeros(size, NPTypeCode.Boolean); + for (int i = 0; i < size; i += 2) + cond[i] = true; + + var x = np.arange(size); + var y = np.zeros(size, NPTypeCode.Int64); + var result = np.where(cond, x, y); + + // Sum of even indices: 0+2+4+...+9998 = 24995000 + Assert.AreEqual(24995000L, (long)np.sum(result)); + } + + #endregion + + #region Type Conversion Edge Cases + + [TestMethod] + public void Where_UnsignedOverflow() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new byte[] { 255, 0 }); + var y = np.array(new byte[] { 0, 255 }); + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(byte), result.dtype); + Assert.AreEqual((byte)255, (byte)result[0]); + Assert.AreEqual((byte)255, (byte)result[1]); + } + + [TestMethod] + public void Where_Decimal() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new decimal[] { 1.23456789m, 0m }); + var y = np.array(new decimal[] { 0m, 9.87654321m }); + var result = np.where(cond, x, y); + + Assert.AreEqual(typeof(decimal), result.dtype); + Assert.AreEqual(1.23456789m, (decimal)result[0]); + Assert.AreEqual(9.87654321m, (decimal)result[1]); + } + + [TestMethod] + public void Where_Char() + { + var cond = np.array(new[] { true, false, true }); + var x = np.array(new char[] { 'A', 'B', 'C' }); + var y = np.array(new char[] { 'X', 'Y', 'Z' }); + var result = np.where(cond, 
x, y); + + Assert.AreEqual(typeof(char), result.dtype); + Assert.AreEqual('A', (char)result[0]); + Assert.AreEqual('Y', (char)result[1]); + Assert.AreEqual('C', (char)result[2]); + } + + #endregion + + #region View Behavior + + [TestMethod] + public void Where_ResultIsNewArray_NotView() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new[] { 1, 2 }); + var y = np.array(new[] { 10, 20 }); + var result = np.where(cond, x, y); + + // Modify original, result should not change + x[0] = 999; + Assert.AreEqual(1, (int)result[0], "Result should be independent of x"); + + y[1] = 999; + Assert.AreEqual(20, (int)result[1], "Result should be independent of y"); + } + + [TestMethod] + public void Where_ModifyResult_DoesNotAffectInputs() + { + var cond = np.array(new[] { true, false }); + var x = np.array(new[] { 1, 2 }); + var y = np.array(new[] { 10, 20 }); + var result = np.where(cond, x, y); + + result[0] = 999; + Assert.AreEqual(1, (int)x[0], "x should not be modified"); + Assert.AreEqual(10, (int)y[0], "y should not be modified"); + } + + #endregion + + #region Alternating Patterns + + [TestMethod] + public void Where_Checkerboard_Pattern() + { + // Create checkerboard condition + var cond = np.zeros(new[] { 4, 4 }, NPTypeCode.Boolean); + for (int i = 0; i < 4; i++) + for (int j = 0; j < 4; j++) + cond[i, j] = (i + j) % 2 == 0; + + var x = np.ones(new[] { 4, 4 }, NPTypeCode.Int32); + var y = np.zeros(new[] { 4, 4 }, NPTypeCode.Int32); + var result = np.where(cond, x, y); + + // Verify checkerboard pattern + Assert.AreEqual(1, (int)result[0, 0]); + Assert.AreEqual(0, (int)result[0, 1]); + Assert.AreEqual(0, (int)result[1, 0]); + Assert.AreEqual(1, (int)result[1, 1]); + } + + [TestMethod] + public void Where_StripedPattern() + { + // Every row alternates between all True and all False + var cond = np.zeros(new[] { 4, 4 }, NPTypeCode.Boolean); + for (int i = 0; i < 4; i++) + for (int j = 0; j < 4; j++) + cond[i, j] = i % 2 == 0; + + var x = np.full(new[] 
{ 4, 4 }, 1);
            var y = np.full(new[] { 4, 4 }, 0);
            var result = np.where(cond, x, y);

            // Rows 0, 2 should be 1; rows 1, 3 should be 0
            for (int j = 0; j < 4; j++)
            {
                Assert.AreEqual(1, (int)result[0, j]);
                Assert.AreEqual(0, (int)result[1, j]);
                Assert.AreEqual(1, (int)result[2, j]);
                Assert.AreEqual(0, (int)result[3, j]);
            }
        }

        #endregion

        #region Empty Array Edge Cases

        /// <summary>
        /// Three-argument np.where on empty (0,3) inputs keeps the (0,3) shape and the double dtype.
        /// </summary>
        [TestMethod]
        public void Where_Empty2D()
        {
            // Empty (0,3) shape
            var cond = np.zeros(new[] { 0, 3 }, NPTypeCode.Boolean);
            var x = np.zeros(new[] { 0, 3 }, NPTypeCode.Double);
            var y = np.zeros(new[] { 0, 3 }, NPTypeCode.Double);
            var result = np.where(cond, x, y);

            result.Should().BeShaped(0, 3);
            Assert.AreEqual(typeof(double), result.dtype);
        }

        /// <summary>
        /// Three-argument np.where on empty (2,0,3) inputs keeps the (2,0,3) shape and the int dtype.
        /// </summary>
        [TestMethod]
        public void Where_Empty3D()
        {
            // Empty (2,0,3) shape
            var cond = np.zeros(new[] { 2, 0, 3 }, NPTypeCode.Boolean);
            var x = np.zeros(new[] { 2, 0, 3 }, NPTypeCode.Int32);
            var y = np.zeros(new[] { 2, 0, 3 }, NPTypeCode.Int32);
            var result = np.where(cond, x, y);

            result.Should().BeShaped(2, 0, 3);
            Assert.AreEqual(typeof(int), result.dtype);
        }

        /// <summary>
        /// Single-argument np.where on an empty (0,3) array returns one empty index array per dimension.
        /// </summary>
        [TestMethod]
        public void Where_SingleArg_Empty2D()
        {
            var arr = np.zeros(new[] { 0, 3 }, NPTypeCode.Int32);
            var result = np.where(arr);

            Assert.AreEqual(2, result.Length); // 2 dimensions
            Assert.AreEqual(0, result[0].size);
            Assert.AreEqual(0, result[1].size);
        }

        #endregion

        #region Error Conditions

        /// <summary>
        /// A (2,) condition cannot be broadcast against (3,) operands, so np.where must throw.
        /// </summary>
        [TestMethod]
        public void Where_IncompatibleShapes_ThrowsException()
        {
            // Shapes (2,) and (3,) cannot be broadcast together
            var cond = np.array(new[] { true, false }); // (2,)
            var x = np.array(new[] { 1, 2, 3 });        // (3,)
            var y = np.array(new[] { 4, 5, 6 });        // (3,)

            // ThrowsException<T> needs an explicit exception type; it cannot be inferred from the lambda.
            // NOTE(review): IncorrectShapeException is assumed to be the type NumSharp's broadcasting
            // raises for mismatched shapes. MSTest matches the type exactly, so confirm before merging.
            Assert.ThrowsException<IncorrectShapeException>(() => np.where(cond, x, y));
        }

        #endregion

        #region NEP50 Type Promotion (NumPy 2.x Parity)

        /// <summary>
        /// Verifies NEP50 weak scalar semantics: when a scalar is combined with an array,
        /// the array's dtype wins for same-kind operations.
        /// </summary>
        [TestMethod]
        public void Where_ScalarTypePromotion_NEP50_WeakScalar()
        {
            // NumPy 2.x: np.where(cond, 1, uint8_array) -> uint8 (weak scalar)
            var cond = np.array(new[] { true, false });
            var yUint8 = np.array(new byte[] { 10, 20 });
            var result = np.where(cond, 1, yUint8);

            // Array dtype wins - matches NumPy 2.x NEP50
            Assert.AreEqual(typeof(byte), result.dtype);
            Assert.AreEqual((byte)1, (byte)result[0]);
            Assert.AreEqual((byte)20, (byte)result[1]);
        }

        /// <summary>
        /// Two same-type scalars preserve their type.
        /// Note: NumPy would return int64 for Python int literals, but C# int32 scalars
        /// cannot be distinguished from explicit np.array(1, dtype=int32), so we preserve.
        /// </summary>
        [TestMethod]
        public void Where_TwoScalars_SameType_Preserved()
        {
            var cond = np.array(new[] { true, false });

            // int + int → int (preserved)
            var result = np.where(cond, 1, 0);
            Assert.AreEqual(typeof(int), result.dtype);
            Assert.AreEqual(1, (int)result[0]);
            Assert.AreEqual(0, (int)result[1]);

            // long + long → long (preserved)
            result = np.where(cond, 1L, 0L);
            Assert.AreEqual(typeof(long), result.dtype);
        }

        /// <summary>
        /// Verifies C# float scalars stay float32 (like np.float32, not Python float).
        /// </summary>
        [TestMethod]
        public void Where_TwoScalars_Float32_StaysFloat32()
        {
            // C# float (1.0f) is like np.float32, not Python's float (which is float64)
            // np.where(cond, np.float32(1.0), np.float32(0.0)) -> float32
            var cond = np.array(new[] { true, false });
            var result = np.where(cond, 1.0f, 0.0f);

            Assert.AreEqual(typeof(float), result.dtype);
        }

        /// <summary>
        /// Verifies NEP50: int scalar + float32 array -> float32 (same-kind, array wins).
        /// </summary>
        [TestMethod]
        public void Where_IntScalar_Float32Array_ReturnsFloat32()
        {
            var cond = np.array(new[] { true, false });
            var yFloat32 = np.array(new float[] { 10.5f, 20.5f });
            var result = np.where(cond, 1, yFloat32);

            // Array dtype wins for same-kind (int->float conversion)
            Assert.AreEqual(typeof(float), result.dtype);
        }

        /// <summary>
        /// Verifies NEP50: float scalar + int32 array -> float64 (cross-kind promotion).
        /// </summary>
        [TestMethod]
        public void Where_FloatScalar_Int32Array_ReturnsFloat64()
        {
            var cond = np.array(new[] { true, false });
            var yInt32 = np.array(new int[] { 10, 20 });
            var result = np.where(cond, 1.5, yInt32);

            // Cross-kind: float scalar forces float64
            Assert.AreEqual(typeof(double), result.dtype);
        }

        #endregion
    }
}
diff --git a/test/NumSharp.UnitTest/Logic/np.where.Test.cs b/test/NumSharp.UnitTest/Logic/np.where.Test.cs
new file mode 100644
index 00000000..b1991bae
--- /dev/null
+++ b/test/NumSharp.UnitTest/Logic/np.where.Test.cs
@@ -0,0 +1,497 @@
using System;
using System.Linq;
using NumSharp.UnitTest.Utilities;
using Assert = Microsoft.VisualStudio.TestTools.UnitTesting.Assert;

namespace NumSharp.UnitTest.Logic
{
    ///
    /// Comprehensive tests for np.where matching NumPy 2.x behavior.
///
    /// NumPy signature: where(condition, x=None, y=None, /)
    /// - Single arg: returns np.nonzero(condition)
    /// - Three args: element-wise selection with broadcasting
    ///
    [TestClass]
    public class np_where_Test
    {
        #region Single Argument (nonzero equivalent)

        /// <summary>
        /// One-argument form on a 1-D int array yields a single array holding the non-zero positions.
        /// </summary>
        [TestMethod]
        public void Where_SingleArg_1D_ReturnsIndices()
        {
            // NumPy: np.where([0, 1, 0, 2, 0, 3]) -> (array([1, 3, 5]),)
            var input = np.array(new[] { 0, 1, 0, 2, 0, 3 });

            var indices = np.where(input);

            Assert.AreEqual(1, indices.Length);
            indices[0].Should().BeOfValues(1L, 3L, 5L);
        }

        /// <summary>
        /// One-argument form on a 2-D int array yields one index array per axis.
        /// </summary>
        [TestMethod]
        public void Where_SingleArg_2D_ReturnsTupleOfIndices()
        {
            // NumPy: np.where([[0, 1, 0], [2, 0, 3]]) -> (array([0, 1, 1]), array([1, 0, 2]))
            var input = np.array(new int[,] { { 0, 1, 0 }, { 2, 0, 3 } });

            var indices = np.where(input);

            Assert.AreEqual(2, indices.Length);
            indices[0].Should().BeOfValues(0L, 1L, 1L); // rows
            indices[1].Should().BeOfValues(1L, 0L, 2L); // columns
        }

        /// <summary>
        /// One-argument form on a bool array reports the positions of the true entries.
        /// </summary>
        [TestMethod]
        public void Where_SingleArg_Boolean_ReturnsNonzero()
        {
            var flags = np.array(new[] { true, false, true, false, true });

            var indices = np.where(flags);

            Assert.AreEqual(1, indices.Length);
            indices[0].Should().BeOfValues(0L, 2L, 4L);
        }

        /// <summary>
        /// One-argument form on an empty 1-D array yields a single empty index array.
        /// </summary>
        [TestMethod]
        public void Where_SingleArg_Empty_ReturnsEmptyIndices()
        {
            var input = np.array(new int[0]);

            var indices = np.where(input);

            Assert.AreEqual(1, indices.Length);
            Assert.AreEqual(0, indices[0].size);
        }

        /// <summary>
        /// One-argument form on an all-false array yields a single empty index array.
        /// </summary>
        [TestMethod]
        public void Where_SingleArg_AllFalse_ReturnsEmptyIndices()
        {
            var flags = np.array(new[] { false, false, false });

            var indices = np.where(flags);

            Assert.AreEqual(1, indices.Length);
            Assert.AreEqual(0, indices[0].size);
        }

        /// <summary>
        /// One-argument form on an all-true array lists every position.
        /// </summary>
        [TestMethod]
        public void Where_SingleArg_AllTrue_ReturnsAllIndices()
        {
            var flags = np.array(new[] { true, true, true });

            var indices = np.where(flags);

            indices[0].Should().BeOfValues(0L, 1L, 2L);
        }

        /// <summary>
        /// One-argument form on a 3-D array yields three index arrays, one per axis.
        /// </summary>
        [TestMethod]
        public void Where_SingleArg_3D_ReturnsTupleOfThreeArrays()
        {
            // A 2x2x2 block of zeros with two entries switched on
            var cube = np.zeros(new[] { 2, 2, 2 }, NPTypeCode.Int32);
            cube[0, 0, 1] = 1;
            cube[1, 1, 0] = 1;

            var indices = np.where(cube);

            Assert.AreEqual(3, indices.Length);
            indices[0].Should().BeOfValues(0L, 1L); // axis 0
            indices[1].Should().BeOfValues(0L, 1L); // axis 1
            indices[2].Should().BeOfValues(1L, 0L); // axis 2
        }

        #endregion

        #region Three Arguments (element-wise selection)

        /// <summary>
        /// Three-argument form picks from x where the condition holds and from y elsewhere.
        /// </summary>
        [TestMethod]
        public void Where_ThreeArgs_Basic_SelectsCorrectly()
        {
            // NumPy: np.where(a < 5, a, 10*a) with a = arange(10)
            var values = np.arange(10);

            var picked = np.where(values < 5, values, 10 * values);

            picked.Should().BeOfValues(0L, 1L, 2L, 3L, 4L, 50L, 60L, 70L, 80L, 90L);
        }

        /// <summary>
        /// Three-argument form with a bool mask and two int arrays of the same length.
        /// </summary>
        [TestMethod]
        public void Where_ThreeArgs_BooleanCondition()
        {
            var whenTrue = np.array(new[] { 1, 2, 3, 4 });
            var whenFalse = np.array(new[] { 10, 20, 30, 40 });
            var mask = np.array(new[] { true, false, true, false });

            var picked = np.where(mask, whenTrue, whenFalse);

            picked.Should().BeOfValues(1, 20, 3, 40);
        }

        /// <summary>
        /// Three-argument form on 2x2 inputs keeps the shape and selects per element.
        /// </summary>
        [TestMethod]
        public void Where_ThreeArgs_2D()
        {
            // NumPy: np.where([[True, False], [True, True]], [[1, 2], [3, 4]], [[9, 8], [7, 6]])
            var whenTrue = np.array(new int[,] { { 1, 2 }, { 3, 4 } });
            var whenFalse = np.array(new int[,] { { 9, 8 }, { 7, 6 } });
            var mask = np.array(new bool[,] { { true, false }, { true, true } });

            var picked = np.where(mask, whenTrue, whenFalse);

            picked.Should().BeShaped(2, 2);
            Assert.AreEqual(1, (int)picked[0, 0]);
            Assert.AreEqual(8, (int)picked[0, 1]);
            Assert.AreEqual(3, (int)picked[1, 0]);
            Assert.AreEqual(4, (int)picked[1, 1]);
        }

        /// <summary>
        /// An int condition is interpreted by truthiness: zero selects y, non-zero selects x.
        /// </summary>
        [TestMethod]
        public void Where_ThreeArgs_NonBoolCondition_TreatsAsTruthy()
        {
            // NumPy: np.where([0, 1, 2, 0], 100, -100) -> [-100, 100, 100, -100]
            var intCondition = np.array(new[] { 0, 1, 2, 0 });

            var picked = np.where(intCondition, 100, -100);

            picked.Should().BeOfValues(-100, 100, 100, -100);
        }

        #endregion

        #region Scalar Arguments

        /// <summary>
        /// A scalar may be passed for x while y is an array.
        /// </summary>
        [TestMethod]
        public void Where_ScalarX()
        {
            var mask = np.array(new[] { true, false, true, false });
            var whenFalse = np.array(new[] { 10, 20, 30, 40 });

            var picked = np.where(mask, 99, whenFalse);

            picked.Should().BeOfValues(99, 20, 99, 40);
        }

        /// <summary>
        /// A scalar may be passed for y while x is an array.
        /// </summary>
        [TestMethod]
        public void Where_ScalarY()
        {
            var mask = np.array(new[] { true, false, true, false });
            var whenTrue = np.array(new[] { 1, 2, 3, 4 });

            var picked = np.where(mask, whenTrue, -1);

            picked.Should().BeOfValues(1, -1, 3, -1);
        }

        /// <summary>
        /// Scalars may be passed for both x and y.
        /// </summary>
        [TestMethod]
        public void Where_BothScalars()
        {
            var mask = np.array(new[] { true, false, true, false });

            var picked = np.where(mask, 1, 0);

            picked.Should().BeOfValues(1, 0, 1, 0);
        }

        /// <summary>
        /// Two double scalars produce a double result carrying the chosen values.
        /// </summary>
        [TestMethod]
        public void Where_ScalarFloat()
        {
            var mask = np.array(new[] { true, false });

            var picked = np.where(mask, 1.5, 2.5);

            Assert.AreEqual(typeof(double), picked.dtype);
            Assert.AreEqual(1.5, (double)picked[0], 1e-10);
            Assert.AreEqual(2.5, (double)picked[1], 1e-10);
        }

        #endregion

        #region Broadcasting

        /// <summary>
        /// A scalar y is broadcast against a 3x3 condition and x.
        /// </summary>
        [TestMethod]
        public void Where_Broadcasting_ScalarY()
        {
            // NumPy: np.where(a < 4, a, -1) on a 3x3 array
            var grid = np.array(new int[,] { { 0, 1, 2 }, { 0, 2, 4 }, { 0, 3, 6 } });

            var picked = np.where(grid < 4, grid, -1);

            picked.Should().BeShaped(3, 3);
            Assert.AreEqual(0, (int)picked[0, 0]);
            Assert.AreEqual(1, (int)picked[0, 1]);
            Assert.AreEqual(2, (int)picked[0, 2]);
            Assert.AreEqual(-1, (int)picked[1, 2]);
            Assert.AreEqual(-1, (int)picked[2, 2]);
        }

        /// <summary>
        /// Condition (2,1), x (3,) and y (1,3) broadcast to a (2,3) result.
        /// </summary>
        [TestMethod]
        public void Where_Broadcasting_DifferentShapes()
        {
            var mask = np.array(new bool[,] { { true }, { false } });   // (2,1)
            var whenTrue = np.array(new[] { 1, 2, 3 });                 // (3,)
            var whenFalse = np.array(new int[,] { { 10, 20, 30 } });    // (1,3)

            var picked = np.where(mask, whenTrue, whenFalse);

            picked.Should().BeShaped(2, 3);

            // First row: mask is true, values come from x
            Assert.AreEqual(1, (int)picked[0, 0]);
            Assert.AreEqual(2, (int)picked[0, 1]);
            Assert.AreEqual(3, (int)picked[0, 2]);

            // Second row: mask is false, values come from y
            Assert.AreEqual(10, (int)picked[1, 0]);
            Assert.AreEqual(20, (int)picked[1, 1]);
            Assert.AreEqual(30, (int)picked[1, 2]);
        }

        /// <summary>
        /// Condition (3,1), scalar x and y (1,4) broadcast to a (3,4) result.
        /// </summary>
        [TestMethod]
        public void Where_Broadcasting_ColumnVector()
        {
            var mask = np.array(new bool[,] { { true }, { false }, { true } });
            var scalarX = 1;
            var whenFalse = np.array(new int[,] { { 10, 20, 30, 40 } });

            var picked = np.where(mask, scalarX, whenFalse);

            picked.Should().BeShaped(3, 4);

            // Row 0 is filled with the scalar
            for (int col = 0; col < 4; col++)
            {
                Assert.AreEqual(1, (int)picked[0, col]);
            }

            // Row 1 mirrors y
            Assert.AreEqual(10, (int)picked[1, 0]);
            Assert.AreEqual(40, (int)picked[1, 3]);

            // Row 2 is filled with the scalar again
            for (int col = 0; col < 4; col++)
            {
                Assert.AreEqual(1, (int)picked[2, col]);
            }
        }

        #endregion

        #region Type Promotion

        /// <summary>
        /// An int scalar paired with a double scalar gives a double result.
        /// </summary>
        [TestMethod]
        public void Where_TypePromotion_IntFloat_ReturnsFloat64()
        {
            var mask = np.array(new[] { true, false });

            var picked = np.where(mask, 1, 2.5);

            Assert.AreEqual(typeof(double), picked.dtype);
            Assert.AreEqual(1.0, (double)picked[0], 1e-10);
            Assert.AreEqual(2.5, (double)picked[1], 1e-10);
        }

        /// <summary>
        /// An int array paired with a long array gives a long result.
        /// </summary>
        [TestMethod]
        public void Where_TypePromotion_Int32Int64_ReturnsInt64()
        {
            var mask = np.array(new[] { true, false });
            var ints = np.array(new int[] { 1 });
            var longs = np.array(new long[] { 2 });

            var picked = np.where(mask, ints, longs);

            Assert.AreEqual(typeof(long), picked.dtype);
        }

        /// <summary>
        /// A float array paired with a double array gives a double result.
        /// </summary>
        [TestMethod]
        public void Where_TypePromotion_FloatDouble_ReturnsDouble()
        {
            var mask = np.array(new[] { true, false });
            var singles = np.array(new float[] { 1.5f });
            var doubles = np.array(new double[] { 2.5 });

            var picked = np.where(mask, singles, doubles);

            Assert.AreEqual(typeof(double), picked.dtype);
        }

        #endregion

        #region Edge Cases

        /// <summary>
        /// Three empty 1-D inputs give an empty result.
        /// </summary>
        [TestMethod]
        public void Where_EmptyArrays_ThreeArgs()
        {
            var mask = np.array(new bool[0]);
            var whenTrue = np.array(new int[0]);
            var whenFalse = np.array(new int[0]);

            var picked = np.where(mask, whenTrue, whenFalse);

            Assert.AreEqual(0, picked.size);
        }

        /// <summary>
        /// A one-element condition with two int scalars gives a one-element int result.
        /// </summary>
        [TestMethod]
        public void Where_SingleElement()
        {
            var mask = np.array(new[] { true });

            var picked = np.where(mask, 42, 0);

            Assert.AreEqual(1, picked.size);
            Assert.AreEqual(typeof(int), picked.dtype); // both scalars are int, so int is kept
            Assert.AreEqual(42, (int)picked[0]);
        }

        /// <summary>
        /// An all-true condition reproduces x.
        /// </summary>
        [TestMethod]
        public void Where_AllTrue_ReturnsAllX()
        {
            var whenTrue = np.array(new[] { 1, 2, 3 });
            var whenFalse = np.array(new[] { 10, 20, 30 });
            var mask = np.array(new[] { true, true, true });

            var picked = np.where(mask, whenTrue, whenFalse);

            picked.Should().BeOfValues(1, 2, 3);
        }

        /// <summary>
        /// An all-false condition reproduces y.
        /// </summary>
        [TestMethod]
        public void Where_AllFalse_ReturnsAllY()
        {
            var whenTrue = np.array(new[] { 1, 2, 3 });
            var whenFalse = np.array(new[] { 10, 20, 30 });
            var mask = np.array(new[] { false, false, false });

            var picked = np.where(mask, whenTrue, whenFalse);

            picked.Should().BeOfValues(10, 20, 30);
        }

        /// <summary>
        /// A 100000-element alternating condition selects between ones and zeros; the size and
        /// the first three elements are checked.
        /// </summary>
        [TestMethod]
        public void Where_LargeArray()
        {
            var length = 100000;
            var ones = np.ones(length, NPTypeCode.Int32);
            var zeros = np.zeros(length, NPTypeCode.Int32);
            var mask = np.arange(length) % 2 == 0; // true at even positions

            var picked = np.where(mask, ones, zeros);

            Assert.AreEqual(length, picked.size);

            // Even positions take the 1, odd positions take the 0
            Assert.AreEqual(1, (int)picked[0]);
            Assert.AreEqual(0, (int)picked[1]);
            Assert.AreEqual(1, (int)picked[2]);
        }

        #endregion

        #region NumPy Output Verification

        /// <summary>
        /// NumPy documentation example with 2x2 inputs; expected [[1, 8], [3, 4]].
        /// </summary>
        [TestMethod]
        public void Where_NumPyExample1()
        {
            // NumPy docs: np.where([[True, False], [True, True]],
            //                      [[1, 2], [3, 4]], [[9, 8], [7, 6]])
            var mask = np.array(new bool[,] { { true, false }, { true, true } });
            var whenTrue = np.array(new int[,] { { 1, 2 }, { 3, 4 } });
            var whenFalse = np.array(new int[,] { { 9, 8 }, { 7, 6 } });

            var picked = np.where(mask, whenTrue, whenFalse);

            Assert.AreEqual(1, (int)picked[0, 0]);
            Assert.AreEqual(8, (int)picked[0, 1]);
            Assert.AreEqual(3, (int)picked[1, 0]);
            Assert.AreEqual(4, (int)picked[1, 1]);
        }

        /// <summary>
        /// NumPy documentation example np.where(a &lt; 5, a, 10*a) with a = arange(10).
        /// </summary>
        [TestMethod]
        public void Where_NumPyExample2()
        {
            // Expected: array([ 0, 1, 2, 3, 4, 50, 60, 70, 80, 90])
            var values = np.arange(10);

            var picked = np.where(values < 5, values, 10 * values);

            picked.Should().BeOfValues(0L, 1L, 2L, 3L, 4L, 50L, 60L, 70L, 80L, 90L);
        }

        /// <summary>
        /// NumPy documentation example np.where(a &lt; 4, a, -1) on a 3x3 array.
        /// </summary>
        [TestMethod]
        public void Where_NumPyExample3()
        {
            // Expected: array([[ 0, 1, 2], [ 0, 2, -1], [ 0, 3, -1]])
            var grid = np.array(new int[,] { { 0, 1, 2 }, { 0, 2, 4 }, { 0, 3, 6 } });

            var picked = np.where(grid < 4, grid, -1);

            Assert.AreEqual(0, (int)picked[0, 0]);
            Assert.AreEqual(1, (int)picked[0, 1]);
            Assert.AreEqual(2, (int)picked[0, 2]);
            Assert.AreEqual(0, (int)picked[1, 0]);
            Assert.AreEqual(2, (int)picked[1, 1]);
            Assert.AreEqual(-1, (int)picked[1, 2]);
            Assert.AreEqual(0, (int)picked[2, 0]);
            Assert.AreEqual(3, (int)picked[2, 1]);
            Assert.AreEqual(-1, (int)picked[2, 2]);
        }

        #endregion

        #region Dtype Coverage

        /// <summary>
        /// Byte operands give a byte result.
        /// </summary>
        [TestMethod]
        public void Where_Dtype_Byte()
        {
            var whenTrue = np.array(new byte[] { 1, 2 });
            var whenFalse = np.array(new byte[] { 10, 20 });
            var mask = np.array(new[] { true, false });

            var picked = np.where(mask, whenTrue, whenFalse);

            Assert.AreEqual(typeof(byte), picked.dtype);
            picked.Should().BeOfValues((byte)1, (byte)20);
        }

        /// <summary>
        /// Short operands give a short result.
        /// </summary>
        [TestMethod]
        public void Where_Dtype_Int16()
        {
            var whenTrue = np.array(new short[] { 1, 2 });
            var whenFalse = np.array(new short[] { 10, 20 });
            var mask = np.array(new[] { true, false });

            var picked = np.where(mask, whenTrue, whenFalse);

            Assert.AreEqual(typeof(short), picked.dtype);
            picked.Should().BeOfValues((short)1, (short)20);
        }

        /// <summary>
        /// Int operands give an int result.
        /// </summary>
        [TestMethod]
        public void Where_Dtype_Int32()
        {
            var whenTrue = np.array(new int[] { 1, 2 });
            var whenFalse = np.array(new int[] { 10, 20 });
            var mask = np.array(new[] { true, false });

            var picked = np.where(mask, whenTrue, whenFalse);

            Assert.AreEqual(typeof(int), picked.dtype);
            picked.Should().BeOfValues(1, 20);
        }

        /// <summary>
        /// Long operands give a long result.
        /// </summary>
        [TestMethod]
        public void Where_Dtype_Int64()
        {
            var whenTrue = np.array(new long[] { 1, 2 });
            var whenFalse = np.array(new long[] { 10, 20 });
            var mask = np.array(new[] { true, false });

            var picked = np.where(mask, whenTrue, whenFalse);

            Assert.AreEqual(typeof(long), picked.dtype);
            picked.Should().BeOfValues(1L, 20L);
        }

        /// <summary>
        /// Float operands give a float result.
        /// </summary>
        [TestMethod]
        public void Where_Dtype_Single()
        {
            var whenTrue = np.array(new float[] { 1.5f, 2.5f });
            var whenFalse = np.array(new float[] { 10.5f, 20.5f });
            var mask = np.array(new[] { true, false });

            var picked = np.where(mask, whenTrue, whenFalse);

            Assert.AreEqual(typeof(float), picked.dtype);
            Assert.AreEqual(1.5f, (float)picked[0], 1e-6f);
            Assert.AreEqual(20.5f, (float)picked[1], 1e-6f);
        }

        /// <summary>
        /// Double operands give a double result.
        /// </summary>
        [TestMethod]
        public void Where_Dtype_Double()
        {
            var whenTrue = np.array(new double[] { 1.5, 2.5 });
            var whenFalse = np.array(new double[] { 10.5, 20.5 });
            var mask = np.array(new[] { true, false });

            var picked = np.where(mask, whenTrue, whenFalse);

            Assert.AreEqual(typeof(double), picked.dtype);
            Assert.AreEqual(1.5, (double)picked[0], 1e-10);
            Assert.AreEqual(20.5, (double)picked[1], 1e-10);
        }

        /// <summary>
        /// Bool operands give a bool result.
        /// </summary>
        [TestMethod]
        public void Where_Dtype_Boolean()
        {
            var whenTrue = np.array(new bool[] { true, true });
            var whenFalse = np.array(new bool[] { false, false });
            var mask = np.array(new[] { true, false });

            var picked = np.where(mask, whenTrue, whenFalse);

            Assert.AreEqual(typeof(bool), picked.dtype);
            Assert.IsTrue((bool)picked[0]);
            Assert.IsFalse((bool)picked[1]);
        }

        #endregion
    }
}