Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pytential/symbolic/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,9 @@ def map_interpolation(self, expr: pp.Interpolation):
else:
raise TypeError(f"cannot interpolate '{type(operand).__name__}'")

def map_bremer_weighted_density(self, expr: pp.BremerWeightedDensity):
return self.rec(expr.operand)

def map_interleave(self, expr: pp.Interleave):
return interleave_dof_arrays(
self.places.get_discretization(
Expand Down
92 changes: 82 additions & 10 deletions pytential/symbolic/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,13 @@ def map_interpolation(self, expr: pp.Interpolation):

return type(expr)(expr.from_dd, expr.to_dd, operand)

def map_bremer_weighted_density(self, expr: pp.BremerWeightedDensity):
operand = self.rec_arith(expr.operand)
if operand is expr.operand:
return expr

return type(expr)(operand)

def map_interleave(self, expr: pp.Interleave):
operand_1 = self.rec_arith(expr.operand_1)
operand_2 = self.rec_arith(expr.operand_2)
Expand Down Expand Up @@ -227,6 +234,7 @@ def _map_with_operand(self, expr:
| pp.ElementwiseMin
| pp.ElementwiseMax
| pp.Interpolation
| pp.BremerWeightedDensity
):
return self.rec(expr.operand)

Expand All @@ -242,6 +250,8 @@ def _map_with_operand(self, expr:
map_elementwise_max: \
Callable[[Self, pp.ElementwiseMax], ResultT] = _map_with_operand
map_interpolation: Callable[[Self, pp.Interpolation], ResultT] = _map_with_operand
map_bremer_weighted_density: \
Callable[[Self, pp.BremerWeightedDensity], ResultT] = _map_with_operand

def map_interleave(self, expr: pp.Interleave):
return self.combine([self.rec(expr.operand_1), self.rec(expr.operand_2)])
Expand Down Expand Up @@ -353,6 +363,13 @@ def map_num_reference_derivative(self, expr: pp.NumReferenceDerivative):

return type(expr)(expr.ref_axes, operand, expr.dofdesc)

def map_bremer_weighted_density(self, expr: pp.BremerWeightedDensity):
operand = self.rec(expr.operand)
if operand is expr.operand:
return expr

return type(expr)(operand)

def map_int_g(self, expr: pp.IntG):
densities, kernel_arguments, changed = rec_int_g_arguments(self, expr)
if not changed:
Expand All @@ -367,9 +384,7 @@ def map_common_subexpression(self, expr: p.CommonSubexpression):
return expr

return pp.cse(
child,
expr.prefix,
expr.scope)
child, expr.prefix, expr.scope)

# }}}

Expand Down Expand Up @@ -758,10 +773,37 @@ class EarlyInterpolationAdder(
"""
from_dd: DOFDescriptor
to_dd: DOFDescriptor
variable_from_dd: DOFDescriptor | None = None

@override
def map_variable(self, expr: p.Variable):
return pp.interpolate(expr, self.from_dd, self.to_dd)
from_dd = self.from_dd
if self.variable_from_dd is not None:
from_dd = self.variable_from_dd
return pp.interpolate(expr, from_dd, self.to_dd)

@override
def map_subscript(self, expr: p.Subscript):
if isinstance(expr.aggregate, p.Variable):
from_dd = self.from_dd
if self.variable_from_dd is not None:
from_dd = self.variable_from_dd
return pp.interpolate(expr, from_dd, self.to_dd)

return super().map_subscript(expr)

def map_q_weight(self, expr: pp.QWeight):
raise ValueError(
"EarlyInterpolationAdder reached a bare QWeight.")

def map_bremer_weighted_density(
self,
expr: pp.BremerWeightedDensity,
) -> Expression:
from_dd = self.from_dd
if self.variable_from_dd is not None:
from_dd = self.variable_from_dd
return pp.interpolate(expr, from_dd, self.to_dd)

@override
def map_call(self,
Expand All @@ -778,6 +820,12 @@ def map_call(self,
def handle_unsupported_expression(self, expr: p.ExpressionNode) -> Expression:
return pp.interpolate(expr, self.from_dd, self.to_dd)

@override
def map_common_subexpression(self,
expr: p.CommonSubexpression, /,
) -> Expression:
return CSECachingMapperMixin.map_common_subexpression(self, expr)

@override
def map_common_subexpression_uncached(self,
expr: p.CommonSubexpression, /,
Expand Down Expand Up @@ -852,15 +900,21 @@ def map_int_g(self, expr: pp.IntG):
if not isinstance(lpot_source, QBXLayerPotentialSource):
return expr

from_dd = expr.source.to_stage1()
to_dd = from_dd.to_quad_stage2()
interp_adder = EarlyInterpolationAdder(from_dd, to_dd)
variable_from_dd = expr.source.to_stage1()
to_dd = variable_from_dd.to_quad_stage2()

# stage1 density discretization for geometry can give wrong values for
# quantities that depend on the stage2 element parameterization, such as
# area_element.
geometry_from_dd = variable_from_dd.copy(discr_stage=self.from_discr_stage)

density_interp_adder = EarlyInterpolationAdder(
geometry_from_dd, to_dd, variable_from_dd=variable_from_dd)
densities = tuple(
interp_adder.rec_arith(self.rec_arith(density))
density_interp_adder.rec_arith(self.rec_arith(density))
for density in expr.densities)

from_dd = from_dd.copy(discr_stage=self.from_discr_stage)
interp_adder = EarlyInterpolationAdder(from_dd, to_dd)
interp_adder = EarlyInterpolationAdder(geometry_from_dd, to_dd)
kernel_arguments = constantdict({
name: componentwise(
lambda aexpr: interp_adder.rec_arith(
Expand Down Expand Up @@ -1049,6 +1103,14 @@ def map_interpolation(self, expr: pp.Interpolation, enclosing_prec: int):
stringify_where(expr.to_dd),
self.rec(expr.operand, PREC_NONE))

def map_bremer_weighted_density(
self,
expr: pp.BremerWeightedDensity,
enclosing_prec: int,
):
return "BremerWeightedDensity({})".format(
self.rec(expr.operand, PREC_NONE))

def map_interleave(self, expr: pp.Interleave, enclosing_prec: int):
return "Interleave[{}]({}, {})".format(
stringify_where(expr.from_dd),
Expand Down Expand Up @@ -1107,6 +1169,16 @@ def map_map_node_sum(self, expr: pp.NodeSum):

map_q_weight = map_pytential_leaf

def map_bremer_weighted_density(self, expr: pp.BremerWeightedDensity):
self.lines.append(
'{} [label="BremerWeightedDensity",shape=circle];'.format(
self.get_id(expr)))
if not self.visit(expr, node_printed=True):
return

self.rec(expr.operand)
self.post_visit(expr)

def map_int_g(self, expr: pp.IntG):
descr = "Int[%s->%s]@(%d) (%s)" % (
stringify_where(expr.source),
Expand Down
22 changes: 16 additions & 6 deletions pytential/symbolic/pde/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ class L2WeightedPDEOperator(ABC):
.. automethod:: get_sqrt_weight

.. automethod:: get_density_var
.. automethod:: get_weighted_density
.. automethod:: prepare_rhs

.. automethod:: representation
Expand Down Expand Up @@ -155,6 +156,17 @@ def get_density_var(self, name: str) -> sym.var:
"""
return sym.var(name)

def get_weighted_density(
self,
u: ArithmeticExpression,
dofdesc: DOFDescriptorLike = None,
) -> ArithmeticExpression:
if not self.use_l2_weighting:
return sym.cse(u)

return sym.cse(sym.bremer_weighted_density(
u / self.get_sqrt_weight(dofdesc)))

@abstractmethod
def representation(self,
u: ArithmeticExpression,
Expand Down Expand Up @@ -284,8 +296,7 @@ def representation(self,
source: DOFDescriptorLike = None,
target: DOFDescriptorLike = None,
**kwargs: Operand) -> ArithmeticExpression:
sqrt_w = self.get_sqrt_weight(source)
inv_sqrt_w_u = sym.cse(u/sqrt_w)
inv_sqrt_w_u = self.get_weighted_density(u, source)

if map_potentials is None:
def default_map_potentials(x: ArithmeticExpression) -> ArithmeticExpression:
Expand Down Expand Up @@ -319,7 +330,7 @@ def operator(self,

dofdesc = sym.as_dofdesc(dofdesc)
sqrt_w = self.get_sqrt_weight(dofdesc)
inv_sqrt_w_u = sym.cse(u/sqrt_w)
inv_sqrt_w_u = self.get_weighted_density(u, dofdesc)

if self.is_unique_only_up_to_constant():
# The exterior Dirichlet operator in this representation
Expand Down Expand Up @@ -451,8 +462,7 @@ def default_map_potentials(x: ArithmeticExpression) -> ArithmeticExpression:
from sumpy.kernel import LaplaceKernel
laplace = LaplaceKernel(self.dim)

sqrt_w = self.get_sqrt_weight(source)
inv_sqrt_w_u = sym.cse(u/sqrt_w)
inv_sqrt_w_u = self.get_weighted_density(u, source)
laplace_s_inv_sqrt_w_u = sym.cse(
sym.S(laplace, inv_sqrt_w_u,
qbx_forced_limit=+1,
Expand Down Expand Up @@ -485,7 +495,7 @@ def operator(self,

dofdesc = sym.as_dofdesc(dofdesc)
sqrt_w = self.get_sqrt_weight(dofdesc)
inv_sqrt_w_u = sym.cse(u/sqrt_w)
inv_sqrt_w_u = self.get_weighted_density(u, dofdesc)
laplace_s_inv_sqrt_w_u = sym.cse(
sym.S(laplace, inv_sqrt_w_u,
qbx_forced_limit=+1,
Expand Down
51 changes: 47 additions & 4 deletions pytential/symbolic/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,11 @@
NablaComponent,
)
from pymbolic.primitives import ( # noqa: N813
CommonSubexpression,
Variable as var,
cse_scope as cse_scope_base,
expr_dataclass,
make_common_subexpression as cse,
make_common_subexpression as _cse,
make_sym_vector,
)
from pymbolic.typing import ArithmeticExpression
Expand Down Expand Up @@ -92,7 +93,7 @@

import modepy as mp
from pymbolic.mapper.stringifier import StringifyMapper
from pymbolic.primitives import CommonSubexpression, Quotient
from pymbolic.primitives import Quotient


__doc__ = """
Expand Down Expand Up @@ -306,6 +307,15 @@
Operators
^^^^^^^^^

.. autofunction:: cse

.. autoclass:: BremerWeightedDensity
:show-inheritance:
:undoc-members:
:members: mapper_method

.. autofunction:: bremer_weighted_density

.. autoclass:: Interpolation
:show-inheritance:
:undoc-members:
Expand Down Expand Up @@ -404,7 +414,7 @@

"IsShapeClass", "QWeight", "nodes", "parametrization_derivative",
"parametrization_derivative_matrix", "pseudoscalar", "area_element",
"sqrt_jac_q_weight", "normal", "mean_curvature",
"sqrt_jac_q_weight", "normal", "bremer_weighted_density", "mean_curvature",
"first_fundamental_form", "second_fundamental_form", "shape_operator",

"expansion_radii", "expansion_centers", "h_max", "weights_and_area_elements",
Expand All @@ -413,7 +423,8 @@
"ElementwiseMax", "integral", "Ones", "ones_vec", "area", "mean",
"IterativeInverse",

"Interpolation", "interpolate",
"Interpolation", "interpolate", "BremerWeightedDensity",
"bremer_weighted_density",

"Derivative",

Expand Down Expand Up @@ -527,6 +538,38 @@ class ErrorExpression(ExpressionNode):
"""The error message to raise when this expression is encountered."""


@expr_dataclass()
class BremerWeightedDensity(ExpressionNode):
"""A right-preconditioned density with Bremer/L2 quadrature weighting.

.. autoattribute:: operand
"""

operand: ArithmeticExpression


@for_each_expression
def bremer_weighted_density(
operand: ArithmeticExpression,
) -> ArithmeticExpression:
"""Mark *operand* as a Bremer/L2-weighted density."""

return BremerWeightedDensity(operand)


def cse(
expr: Operand,
prefix: str | None = None,
scope: str | None = None,
*,
wrap_vars: bool = True,
) -> Operand:
if scope is None:
scope = cse_scope.EVALUATION

return cast("Operand", _cse(expr, prefix, scope, wrap_vars=wrap_vars))


def make_sym_mv(name: str, num_components: int) -> MultiVector[ArithmeticExpression]:
return MultiVector(make_sym_vector(name, num_components))

Expand Down
3 changes: 1 addition & 2 deletions test/test_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,8 +525,7 @@ def test_mapper_int_g_term_collector(op_name, k=0):
expr_only_intgs = IntGTermCollector()(expr)

# FIXME: how to check this did something?
sigma = sym.cse(op.get_density_var("sigma") / op.get_sqrt_weight(),
scope=sym.cse_scope.EVALUATION)
sigma = op.get_weighted_density(op.get_density_var("sigma"))
if op_name == "dirichlet":
expected_expr = -1 * sym.D(op.kernel, sigma, qbx_forced_limit="avg")
elif op_name == "neumann":
Expand Down
Loading