Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ This release is compatible with NumPy 2.5.
* Cleaned up Python bindings for indexing functions, renaming `usm_ndarray_take` and `usm_ndarray_put` to `py_take` and `py_put` and refactoring validation [#2935](https://github.com/IntelPython/dpnp/pull/2935)
* Updated `dpnp.linalg.eig` and `dpnp.linalg.eigvals` documentation to reflect NumPy's always-complex eigenvalue output for general matrices [#2953](https://github.com/IntelPython/dpnp/pull/2953)
* Clarified support for negative axes in `dpnp.transpose`/`dpnp.permute_dims` documentation [#2940](https://github.com/IntelPython/dpnp/pull/2940)
* Allowed `dpnp.take` and `dpnp.compress` to cast the result into an `out` array of a different but same-kind dtype [#2959](https://github.com/IntelPython/dpnp/pull/2959)

### Deprecated

Expand Down
22 changes: 13 additions & 9 deletions dpnp/dpnp_iface_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def _take_index(x, inds, axis, q, usm_type, out=None, mode=0):
raise IndexError("cannot take non-empty indices from an empty axis")
res_sh = x_sh[:axis] + inds.shape + x_sh[axis_end:]

x_dt = x.dtype
if out is not None:
out = dpnp.get_usm_ndarray(out)

Expand All @@ -288,21 +289,24 @@ def _take_index(x, inds, axis, q, usm_type, out=None, mode=0):
f"Expected output shape is {res_sh}, got {out.shape}"
)

if x.dtype != out.dtype:
raise TypeError(
f"Output array of type {x.dtype} is needed, " f"got {out.dtype}"
)

if dpt.get_execution_queue((q, out.sycl_queue)) is None:
raise ExecutionPlacementError(
"Input and output allocation queues are not compatible"
)

if ti._array_overlap(x, out):
if x_dt != out.dtype:
if not dpnp.can_cast(x_dt, out.dtype, casting="same_kind"):
raise TypeError(
f"Output array of type {x_dt} is needed, got {out.dtype}"
)

# tensor.take() requires `out` to match the input dtype
out = dpt.empty_like(out, dtype=x_dt)
elif ti._array_overlap(x, out):
# Allocate a temporary buffer to avoid memory overlapping.
out = dpt.empty_like(out)
else:
out = dpt.empty(res_sh, dtype=x.dtype, usm_type=usm_type, sycl_queue=q)
out = dpt.empty(res_sh, dtype=x_dt, usm_type=usm_type, sycl_queue=q)

_manager = dpu.SequentialOrderManager[q]
dep_evs = _manager.submitted_events
Expand Down Expand Up @@ -419,7 +423,7 @@ def compress(condition, a, axis=None, out=None):

res = _take_index(a_ary, inds[0], axis, exec_q, res_usm_type, out=out)

return dpnp.get_result_array(res, out=out)
return dpnp.get_result_array(res, out=out, casting="same_kind")


def diag_indices(n, ndim=2, device=None, usm_type="device", sycl_queue=None):
Expand Down Expand Up @@ -2170,7 +2174,7 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
usm_a, usm_ind, axis, exec_q, res_usm_type, out=out, mode=mode
)

return dpnp.get_result_array(usm_res, out=out)
return dpnp.get_result_array(usm_res, out=out, casting="same_kind")


def take_along_axis(a, indices, axis=-1, mode="wrap"):
Expand Down
65 changes: 62 additions & 3 deletions dpnp/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,6 +952,45 @@ def test_axis_as_array(self):
assert_raises(TypeError, ia.take, [0], axis=ia)
assert_raises(TypeError, a.take, [0], axis=a)

@testing.with_requires("numpy>=2.5")
@pytest.mark.usefixtures(
"suppress_complex_warning", "suppress_invalid_numpy_warnings"
)
@pytest.mark.parametrize("out_dt", get_all_dtypes(no_none=True))
def test_out_dtype(self, out_dt):
a = numpy.array([1, 2, 3, 4], dtype="i4")
ind = numpy.array([0, 2, 3])
out = numpy.empty(3, dtype=out_dt)
ia, iind, iout = dpnp.array(a), dpnp.array(ind), dpnp.array(out)

if numpy.can_cast(a.dtype, out_dt, casting="same_kind"):
# casting the result into `out` of a same-kind dtype is allowed
result = ia.take(iind, out=iout)
expected = a.take(ind, out=out)
assert result is iout
assert_array_equal(result, expected)
else:
# NumPy only deprecates casting the result into `out` of a
# different kind, while dpnp does not allow it
with pytest.warns(DeprecationWarning, match="casting of output"):
a.take(ind, out=out)

with pytest.raises(TypeError, match="Output array"):
ia.take(iind, out=iout)

def test_overlapping_out(self):
a = numpy.arange(6)
ind = numpy.array([0, 1])
ia, iind = dpnp.array(a), dpnp.array(ind)

iout = ia[2:4]
result = dpnp.take(ia, iind, out=iout)
assert result is iout
assert (ia[2:4] == iout).all()

expected = numpy.take(a, ind)
assert_array_equal(expected, result)

def test_mode_raise(self):
a = dpnp.array([[1, 2], [3, 4]])
assert_raises(ValueError, a.take, [-1, 4], mode="raise")
Expand Down Expand Up @@ -1565,14 +1604,34 @@ def test_compress_invalid_out_errors(self):
out_bad_queue = dpnp.empty(1, dtype="i4", sycl_queue=q2)
with pytest.raises(ExecutionPlacementError):
dpnp.compress(condition, a, out=out_bad_queue)
out_bad_dt = dpnp.empty(1, dtype="i8", sycl_queue=q1)
with pytest.raises(TypeError):
dpnp.compress(condition, a, out=out_bad_dt)
out_read_only = dpnp.empty(1, dtype="i4", sycl_queue=q1)
out_read_only.flags.writable = False
with pytest.raises(ValueError):
dpnp.compress(condition, a, out=out_read_only)

@testing.with_requires("numpy>=2.5")
@pytest.mark.usefixtures("suppress_complex_warning")
@pytest.mark.parametrize("out_dt", get_all_dtypes(no_none=True))
def test_out_dtype(self, out_dt):
a = numpy.array([[1, 2], [3, 4]], dtype="i4")
out = numpy.empty((2, 1), dtype=out_dt)
ia, iout = dpnp.array(a), dpnp.array(out)

if numpy.can_cast(a.dtype, out_dt, casting="same_kind"):
# casting the result into `out` of a same-kind dtype is allowed
result = dpnp.compress([1, 0], ia, axis=1, out=iout)
expected = numpy.compress([1, 0], a, axis=1, out=out)
assert result is iout
assert_array_equal(result, expected)
else:
# NumPy only deprecates casting the result into `out` of a
# different kind, while dpnp does not allow it
with pytest.warns(DeprecationWarning, match="casting of output"):
numpy.compress([1, 0], a, axis=1, out=out)

with pytest.raises(TypeError, match="Output array"):
dpnp.compress([1, 0], ia, axis=1, out=iout)

def test_compress_empty_axis(self):
a = dpnp.ones((10, 0, 5), dtype="i4")
condition = [True, False, True]
Expand Down
15 changes: 0 additions & 15 deletions dpnp/tests/third_party/cupy/core_tests/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,21 +558,6 @@ def test_shape_mismatch(self):
wrap_take(a, i, out=o)


@testing.parameterize(
{"shape": (3, 4, 5), "indices": (2, 3), "out_shape": (2, 3)},
{"shape": (), "indices": (), "out_shape": ()},
)
class TestNdarrayTakeErrorTypeMismatch(unittest.TestCase):

def test_output_type_mismatch(self):
for xp in (numpy, cupy):
a = testing.shaped_arange(self.shape, xp, numpy.int32)
i = testing.shaped_arange(self.indices, xp, numpy.int32) % 3
o = testing.shaped_arange(self.out_shape, xp, numpy.float32)
with pytest.raises(TypeError):
wrap_take(a, i, out=o)


@testing.parameterize(
{"shape": (0,), "indices": (0,), "axis": None},
{"shape": (0,), "indices": (0, 1), "axis": None},
Expand Down
Loading