Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions arraycontext/container/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,8 +1095,9 @@ def _flat_size(subary: ArrayOrContainerOrScalar) -> Array | int | np.integer:
try:
iterable = serialize_container(subary)
except NotAnArrayContainerError:
assert not is_array_container(subary)
assert not is_scalar_like(subary)
if TYPE_CHECKING:
assert not is_array_container(subary)
assert not is_scalar_like(subary)

if common_dtype is None:
common_dtype = subary.dtype
Expand Down
4 changes: 2 additions & 2 deletions arraycontext/impl/jax/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def inner_bcast(ary: ArrayOrScalar) -> ArrayOrScalar:
if is_scalar_like(ary):
return ary
else:
assert isinstance(ary, np.ndarray)
assert isinstance(ary, jnp.ndarray)
return cast("Array", cast("object", jnp.broadcast_to(ary, shape)))

return rec_map_container(inner_bcast, array)
Expand All @@ -145,7 +145,7 @@ def vdot(self, a, b, dtype=None):
from arraycontext import rec_multimap_reduce_array_container

def _rec_vdot(ary1, ary2):
common_dtype = np.result_type(ary1, ary2)
common_dtype = jnp.result_type(ary1, ary2)
if dtype not in (None, common_dtype):
raise NotImplementedError(
f"{type(self).__name__} cannot take dtype in vdot.")
Expand Down
2 changes: 1 addition & 1 deletion arraycontext/impl/pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,7 @@ def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays
import pytato as pt
dag = pt.tag_all_calls_to_be_inlined(dag)
dag = pt.inline_calls(dag)
return pt.transform.materialize_with_mpms(dag)
return pt.materialize_with_mpms(dag)

@override
def einsum(self, spec, *args, arg_names=None, tagged=()):
Expand Down
2 changes: 1 addition & 1 deletion test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ def test_array_context_csr_matmul(actx_factory: ArrayContextFactory):
momentum=obj_array.new_1d([x] * 3),
enthalpy=x)

elem_values = actx.zeros((n//2,), dtype=np.float64) + 1.
elem_values = actx.np.zeros((n//2,), dtype=np.float64) + 1.
elem_col_indices = actx.from_numpy(2*np.arange(n//2, dtype=np.int32))
row_starts = actx.from_numpy(np.arange(n//2 + 1, dtype=np.int32))

Expand Down
Loading