diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b92406ad7b..2267827dc04 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ This release is compatible with NumPy 2.5. * Cleaned up Python bindings for indexing functions, renaming `usm_ndarray_take` and `usm_ndarray_put` to `py_take` and `py_put` and refactoring validation [#2935](https://github.com/IntelPython/dpnp/pull/2935) * Updated `dpnp.linalg.eig` and `dpnp.linalg.eigvals` documentation to reflect NumPy's always-complex eigenvalue output for general matrices [#2953](https://github.com/IntelPython/dpnp/pull/2953) * Clarified support for negative axes in `dpnp.transpose`/`dpnp.permute_dims` documentation [#2940](https://github.com/IntelPython/dpnp/pull/2940) +* Allowed `dpnp.take` and `dpnp.compress` to cast the result into an `out` array of a different but same-kind dtype [#2959](https://github.com/IntelPython/dpnp/pull/2959) ### Deprecated diff --git a/dpnp/dpnp_iface_indexing.py b/dpnp/dpnp_iface_indexing.py index 90e76a6c9e9..2bf1594e0f5 100644 --- a/dpnp/dpnp_iface_indexing.py +++ b/dpnp/dpnp_iface_indexing.py @@ -276,6 +276,7 @@ def _take_index(x, inds, axis, q, usm_type, out=None, mode=0): raise IndexError("cannot take non-empty indices from an empty axis") res_sh = x_sh[:axis] + inds.shape + x_sh[axis_end:] + x_dt = x.dtype if out is not None: out = dpnp.get_usm_ndarray(out) @@ -288,21 +289,24 @@ def _take_index(x, inds, axis, q, usm_type, out=None, mode=0): f"Expected output shape is {res_sh}, got {out.shape}" ) - if x.dtype != out.dtype: - raise TypeError( - f"Output array of type {x.dtype} is needed, " f"got {out.dtype}" - ) - if dpt.get_execution_queue((q, out.sycl_queue)) is None: raise ExecutionPlacementError( "Input and output allocation queues are not compatible" ) - if ti._array_overlap(x, out): + if x_dt != out.dtype: + if not dpnp.can_cast(x_dt, out.dtype, casting="same_kind"): + raise TypeError( + f"Output array of type {x_dt} is needed, got {out.dtype}" + ) + + # tensor.take() requires `out` to match the input dtype + out = dpt.empty_like(out, dtype=x_dt) + elif ti._array_overlap(x, out): # Allocate a temporary buffer to avoid memory overlapping. out = dpt.empty_like(out) else: - out = dpt.empty(res_sh, dtype=x.dtype, usm_type=usm_type, sycl_queue=q) + out = dpt.empty(res_sh, dtype=x_dt, usm_type=usm_type, sycl_queue=q) _manager = dpu.SequentialOrderManager[q] dep_evs = _manager.submitted_events @@ -419,7 +423,7 @@ def compress(condition, a, axis=None, out=None): res = _take_index(a_ary, inds[0], axis, exec_q, res_usm_type, out=out) - return dpnp.get_result_array(res, out=out) + return dpnp.get_result_array(res, out=out, casting="same_kind") def diag_indices(n, ndim=2, device=None, usm_type="device", sycl_queue=None): @@ -2170,7 +2174,7 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"): usm_a, usm_ind, axis, exec_q, res_usm_type, out=out, mode=mode ) - return dpnp.get_result_array(usm_res, out=out) + return dpnp.get_result_array(usm_res, out=out, casting="same_kind") def take_along_axis(a, indices, axis=-1, mode="wrap"): diff --git a/dpnp/tests/test_indexing.py b/dpnp/tests/test_indexing.py index 0331e7151f0..b7dc5f791b7 100644 --- a/dpnp/tests/test_indexing.py +++ b/dpnp/tests/test_indexing.py @@ -952,6 +952,45 @@ def test_axis_as_array(self): assert_raises(TypeError, ia.take, [0], axis=ia) assert_raises(TypeError, a.take, [0], axis=a) + @testing.with_requires("numpy>=2.5") + @pytest.mark.usefixtures( + "suppress_complex_warning", "suppress_invalid_numpy_warnings" + ) + @pytest.mark.parametrize("out_dt", get_all_dtypes(no_none=True)) + def test_out_dtype(self, out_dt): + a = numpy.array([1, 2, 3, 4], dtype="i4") + ind = numpy.array([0, 2, 3]) + out = numpy.empty(3, dtype=out_dt) + ia, iind, iout = dpnp.array(a), dpnp.array(ind), dpnp.array(out) + + if numpy.can_cast(a.dtype, out_dt, casting="same_kind"): + # casting the result into `out` of a same-kind dtype is allowed + result = ia.take(iind, out=iout) + expected = a.take(ind, out=out) + assert result is iout + assert_array_equal(result, expected) + else: + # NumPy only deprecates casting the result into `out` of a + # different kind, while dpnp does not allow it + with pytest.warns(DeprecationWarning, match="casting of output"): + a.take(ind, out=out) + + with pytest.raises(TypeError, match="Output array"): + ia.take(iind, out=iout) + + def test_overlapping_out(self): + a = numpy.arange(6) + ind = numpy.array([0, 1]) + ia, iind = dpnp.array(a), dpnp.array(ind) + + iout = ia[2:4] + result = dpnp.take(ia, iind, out=iout) + assert result is iout + assert (ia[2:4] == iout).all() + + expected = numpy.take(a, ind) + assert_array_equal(expected, result) + def test_mode_raise(self): a = dpnp.array([[1, 2], [3, 4]]) assert_raises(ValueError, a.take, [-1, 4], mode="raise") @@ -1565,14 +1604,34 @@ def test_compress_invalid_out_errors(self): out_bad_queue = dpnp.empty(1, dtype="i4", sycl_queue=q2) with pytest.raises(ExecutionPlacementError): dpnp.compress(condition, a, out=out_bad_queue) - out_bad_dt = dpnp.empty(1, dtype="i8", sycl_queue=q1) - with pytest.raises(TypeError): - dpnp.compress(condition, a, out=out_bad_dt) out_read_only = dpnp.empty(1, dtype="i4", sycl_queue=q1) out_read_only.flags.writable = False with pytest.raises(ValueError): dpnp.compress(condition, a, out=out_read_only) + @testing.with_requires("numpy>=2.5") + @pytest.mark.usefixtures("suppress_complex_warning") + @pytest.mark.parametrize("out_dt", get_all_dtypes(no_none=True)) + def test_out_dtype(self, out_dt): + a = numpy.array([[1, 2], [3, 4]], dtype="i4") + out = numpy.empty((2, 1), dtype=out_dt) + ia, iout = dpnp.array(a), dpnp.array(out) + + if numpy.can_cast(a.dtype, out_dt, casting="same_kind"): + # casting the result into `out` of a same-kind dtype is allowed + result = dpnp.compress([1, 0], ia, axis=1, out=iout) + expected = numpy.compress([1, 0], a, axis=1, out=out) + assert result is iout + assert_array_equal(result, expected) + else: + # NumPy only deprecates casting the result into `out` of a + # different kind, while dpnp does not allow it + with pytest.warns(DeprecationWarning, match="casting of output"): + numpy.compress([1, 0], a, axis=1, out=out) + + with pytest.raises(TypeError, match="Output array"): + dpnp.compress([1, 0], ia, axis=1, out=iout) + def test_compress_empty_axis(self): a = dpnp.ones((10, 0, 5), dtype="i4") condition = [True, False, True] diff --git a/dpnp/tests/third_party/cupy/core_tests/test_ndarray.py b/dpnp/tests/third_party/cupy/core_tests/test_ndarray.py index ac6073a3098..cbb8a5c818d 100644 --- a/dpnp/tests/third_party/cupy/core_tests/test_ndarray.py +++ b/dpnp/tests/third_party/cupy/core_tests/test_ndarray.py @@ -558,21 +558,6 @@ def test_shape_mismatch(self): wrap_take(a, i, out=o) -@testing.parameterize( - {"shape": (3, 4, 5), "indices": (2, 3), "out_shape": (2, 3)}, - {"shape": (), "indices": (), "out_shape": ()}, -) -class TestNdarrayTakeErrorTypeMismatch(unittest.TestCase): - - def test_output_type_mismatch(self): - for xp in (numpy, cupy): - a = testing.shaped_arange(self.shape, xp, numpy.int32) - i = testing.shaped_arange(self.indices, xp, numpy.int32) % 3 - o = testing.shaped_arange(self.out_shape, xp, numpy.float32) - with pytest.raises(TypeError): - wrap_take(a, i, out=o) - - @testing.parameterize( {"shape": (0,), "indices": (0,), "axis": None}, {"shape": (0,), "indices": (0, 1), "axis": None},