diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index ff64e643..c2c05bf5 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -1095,8 +1095,9 @@ def _flat_size(subary: ArrayOrContainerOrScalar) -> Array | int | np.integer: try: iterable = serialize_container(subary) except NotAnArrayContainerError: - assert not is_array_container(subary) - assert not is_scalar_like(subary) + if TYPE_CHECKING: + assert not is_array_container(subary) + assert not is_scalar_like(subary) if common_dtype is None: common_dtype = subary.dtype diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 3ae7f9c6..b96607e5 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -123,7 +123,7 @@ def inner_bcast(ary: ArrayOrScalar) -> ArrayOrScalar: if is_scalar_like(ary): return ary else: - assert isinstance(ary, np.ndarray) + assert isinstance(ary, jnp.ndarray) return cast("Array", cast("object", jnp.broadcast_to(ary, shape))) return rec_map_container(inner_bcast, array) @@ -145,7 +145,7 @@ def vdot(self, a, b, dtype=None): from arraycontext import rec_multimap_reduce_array_container def _rec_vdot(ary1, ary2): - common_dtype = np.result_type(ary1, ary2) + common_dtype = jnp.result_type(ary1, ary2) if dtype not in (None, common_dtype): raise NotImplementedError( f"{type(self).__name__} cannot take dtype in vdot.") diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 7c667a79..a8ec6def 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -829,7 +829,7 @@ def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays import pytato as pt dag = pt.tag_all_calls_to_be_inlined(dag) dag = pt.inline_calls(dag) - return pt.transform.materialize_with_mpms(dag) + return pt.materialize_with_mpms(dag) @override def einsum(self, spec, *args, arg_names=None, tagged=()): diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index de143985..e5336cc2 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -676,7 +676,7 @@ def test_array_context_csr_matmul(actx_factory: ArrayContextFactory): momentum=obj_array.new_1d([x] * 3), enthalpy=x) - elem_values = actx.zeros((n//2,), dtype=np.float64) + 1. + elem_values = actx.np.zeros((n//2,), dtype=np.float64) + 1. elem_col_indices = actx.from_numpy(2*np.arange(n//2, dtype=np.int32)) row_starts = actx.from_numpy(np.arange(n//2 + 1, dtype=np.int32))