diff --git a/docs/SDK_GUIDE.md b/docs/SDK_GUIDE.md index 0a29f9b..ed73e4a 100644 --- a/docs/SDK_GUIDE.md +++ b/docs/SDK_GUIDE.md @@ -193,6 +193,45 @@ query = ( ) ``` +### Filtering on Table Calculations + +Define a table calculation, add it to the query, then filter on it with the same +operators you use for dimensions. Table-calculation conditions are sent under +`filters.tableCalculations`: + +```python +from lightdash import TableCalculation + +profit_ratio = TableCalculation( + name="profit_ratio", + sql="${orders.profit} / ${orders.revenue}", +) + +query = ( + model.query() + .metrics(model.metrics.revenue, model.metrics.profit) + .table_calculations(profit_ratio) + .filter(profit_ratio > 0.2) # only rows where the ratio exceeds 20% +) +``` + +Dimension and table-calculation filters can be combined freely; each is +serialized under its own key: + +```python +query = ( + model.query() + .table_calculations(profit_ratio) + .filter((model.dimensions.country == "USA") & (profit_ratio > 0.2)) +) +``` + +> **Note on `type`:** `TableCalculation` defaults `type="number"`. The data +> type must be set for filter operators to compile — an untyped calculation is +> treated as a string by the API and rejects numeric operators like `>` or +> `between`. For a non-numeric calculation, pass `type="string"` (or `date`, +> `timestamp`, `boolean`). + --- ## Dimensions and Metrics diff --git a/lightdash/__init__.py b/lightdash/__init__.py index e876172..90ce09c 100644 --- a/lightdash/__init__.py +++ b/lightdash/__init__.py @@ -12,7 +12,8 @@ ) from lightdash.query import QueryResult from lightdash.sorting import Sort -from lightdash.filter import DimensionFilter, CompositeFilter +from lightdash.filter import DimensionFilter, TableCalculationFilter, CompositeFilter +from lightdash.table_calculations import TableCalculation from lightdash.sql_runner import SqlResult from lightdash.results import ResultSet, BaseResult @@ -25,7 +26,9 @@ 'QueryResult', 'Sort', 'DimensionFilter', + 'TableCalculationFilter', 'CompositeFilter', + 'TableCalculation', 'SqlResult', 'ResultSet', 'BaseResult', diff --git a/lightdash/filter.py b/lightdash/filter.py index cbd4d31..4f8f68d 100644 --- a/lightdash/filter.py +++ b/lightdash/filter.py @@ -5,6 +5,7 @@ if TYPE_CHECKING: from lightdash.dimensions import Dimension + from lightdash.table_calculations import TableCalculation numeric_filters = [ "isNull", @@ -57,8 +58,24 @@ ) +class _FieldFilterMixin: + """Shared ``&`` / ``|`` combination behavior for single-field filters.""" + + def __and__(self, other: Union["FieldFilter", "CompositeFilter"]) -> "CompositeFilter": + """Combine filters with AND: filter1 & filter2""" + if isinstance(other, CompositeFilter): + return CompositeFilter(filters=[self] + list(other.filters), aggregation="and") + return CompositeFilter(filters=[self, other], aggregation="and") + + def __or__(self, other: Union["FieldFilter", "CompositeFilter"]) -> "CompositeFilter": + """Combine filters with OR: filter1 | filter2""" + if isinstance(other, CompositeFilter): + return CompositeFilter(filters=[self] + list(other.filters), aggregation="or") + return CompositeFilter(filters=[self, other], aggregation="or") + + @dataclass -class DimensionFilter: +class DimensionFilter(_FieldFilterMixin): field: "Dimension" operator: str values: Union[str, int, float, List[str], List[int], List[float]] @@ -87,27 +104,61 @@ def to_dict(self) -> Dict[str, Union[str, List[str]]]: "values": self.values, } - def __and__(self, other: Union["DimensionFilter", "CompositeFilter"]) -> "CompositeFilter": - """Combine filters with AND: filter1 & filter2""" - if isinstance(other, CompositeFilter): - return CompositeFilter(filters=[self] + list(other.filters), aggregation="and") - return CompositeFilter(filters=[self, other], aggregation="and") - def __or__(self, other: Union["DimensionFilter", "CompositeFilter"]) -> "CompositeFilter": - """Combine filters with OR: filter1 | filter2""" - if isinstance(other, CompositeFilter): - return CompositeFilter(filters=[self] + list(other.filters), aggregation="or") - return CompositeFilter(filters=[self, other], aggregation="or") +@dataclass +class TableCalculationFilter(_FieldFilterMixin): + """A filter targeting a table calculation. + + Table calculations are referenced by name (no model prefix) and serialize + under ``filters.tableCalculations`` in the query payload. + """ + + field: Union[str, "TableCalculation"] + operator: str + values: Union[str, int, float, List[str], List[int], List[float]] + + def __post_init__(self): + from lightdash.table_calculations import TableCalculation + + if not isinstance(self.values, list): + self.values = [self.values] + + if self.operator not in allowed_values: + raise ValueError( + f"Invalid operator '{self.operator}'. " + f"Must be one of: {', '.join(sorted(allowed_values))}" + ) + + if not isinstance(self.field, (str, TableCalculation)): + raise TypeError( + "field must be a TableCalculation object or table calculation name, " + f"got {type(self.field).__name__}" + ) + + @property + def field_id(self) -> str: + return self.field if isinstance(self.field, str) else self.field.field_id + + def to_dict(self) -> Dict[str, Union[str, List[str]]]: + return { + "target": {"fieldId": self.field_id}, + "operator": self.operator, + "values": self.values, + } + + +FieldFilter = Union[DimensionFilter, TableCalculationFilter] @dataclass class CompositeFilter: """ - Filters are a list of dimension filters that are applied to a query. + Filters are a list of field filters (on dimensions and table calculations) + that are applied to a query. Later this will also represent complex filters with AND, OR, NOT, etc. """ - filters: List[DimensionFilter] = field(default_factory=list) + filters: List[FieldFilter] = field(default_factory=list) aggregation: str = "and" def __post_init__(self): @@ -117,17 +168,26 @@ def __post_init__(self): ) def to_dict(self): - out = [] + dimensions = [] + table_calculations = [] for f in self.filters: # Check that the filter is not a composite filter if not hasattr(f, "field"): raise TypeError("Multi-level filter composites not supported yet") # Multiple filters may target the same field, e.g. a date range # expressed as (dim >= start) & (dim <= end). - out.append(f.to_dict()) - return {"dimensions": {self.aggregation: out}} - - def __and__(self, other: Union[DimensionFilter, "CompositeFilter"]) -> "CompositeFilter": + if isinstance(f, TableCalculationFilter): + table_calculations.append(f.to_dict()) + else: + dimensions.append(f.to_dict()) + out = {"dimensions": {self.aggregation: dimensions}} + # Only include the tableCalculations group when present, so existing + # dimension-only payloads are unchanged. + if table_calculations: + out["tableCalculations"] = {self.aggregation: table_calculations} + return out + + def __and__(self, other: Union[FieldFilter, "CompositeFilter"]) -> "CompositeFilter": """Combine with another filter using AND: composite & filter""" if isinstance(other, CompositeFilter): # Flatten if both are AND composites @@ -143,7 +203,7 @@ def __and__(self, other: Union[DimensionFilter, "CompositeFilter"]) -> "Composit return CompositeFilter(filters=list(self.filters) + [other], aggregation="and") return CompositeFilter(filters=list(self.filters) + [other], aggregation="and") - def __or__(self, other: Union[DimensionFilter, "CompositeFilter"]) -> "CompositeFilter": + def __or__(self, other: Union[FieldFilter, "CompositeFilter"]) -> "CompositeFilter": """Combine with another filter using OR: composite | filter""" if isinstance(other, CompositeFilter): # Flatten if both are OR composites diff --git a/lightdash/models.py b/lightdash/models.py index 3cd3ff9..90534e3 100644 --- a/lightdash/models.py +++ b/lightdash/models.py @@ -128,6 +128,7 @@ def query( filters: Optional[Union[DimensionFilter, CompositeFilter]] = None, sort: Optional[Union[Sort, Sequence[Sort]]] = None, limit: int = 500, + table_calculations: Optional[Sequence[Any]] = None, ) -> Query: """ Create a query against this model. @@ -143,6 +144,8 @@ def query( filters: Optional filters to apply to the query. sort: Optional Sort object or sequence of Sort objects to order results. limit: Maximum number of rows to return. + table_calculations: Optional sequence of TableCalculation objects or + raw dicts to include in the query. Returns: A Query object that can be used to fetch results or build further. @@ -186,6 +189,7 @@ def query( filters=filters, sort=sort_seq, limit=limit, + table_calculations=table_calculations, ) def list_metrics(self) -> List["Metric"]: diff --git a/lightdash/query.py b/lightdash/query.py index c27a8f7..0de69bf 100644 --- a/lightdash/query.py +++ b/lightdash/query.py @@ -5,7 +5,7 @@ from .dimensions import Dimension from .metrics import Metric -from .filter import DimensionFilter, CompositeFilter +from .filter import DimensionFilter, TableCalculationFilter, CompositeFilter from .sorting import Sort from .types import Model from .exceptions import QueryError, QueryTimeout, QueryCancelled @@ -376,10 +376,10 @@ def __init__( self._dimensions = tuple(dimensions) if dimensions else () self._limit = limit - # Handle filters - normalize to CompositeFilter if DimensionFilter is passed + # Handle filters - normalize to CompositeFilter if a single filter is passed if filters is None: self._filters = None - elif isinstance(filters, DimensionFilter): + elif isinstance(filters, (DimensionFilter, TableCalculationFilter)): self._filters = CompositeFilter(filters=[filters]) else: self._filters = filters @@ -448,7 +448,34 @@ def dimensions(self, *dimensions: Union[str, Dimension]) -> "Query": """ return self._clone(dimensions=self._dimensions + dimensions) - def filter(self, filter: Union[DimensionFilter, CompositeFilter]) -> "Query": + def table_calculations(self, *table_calculations: Any) -> "Query": + """ + Add table calculations to the query. + + Returns a new Query with the specified table calculations added. + + Args: + *table_calculations: TableCalculation objects or raw dicts to add + + Returns: + A new Query with the table calculations added + + Example: + profit_ratio = TableCalculation( + name="profit_ratio", + sql="${orders.profit} / ${orders.revenue}", + ) + query = ( + model.query() + .metrics(model.metrics.revenue, model.metrics.profit) + .table_calculations(profit_ratio) + .filter(profit_ratio > 0.2) + ) + """ + existing = tuple(self._table_calculations) if self._table_calculations else () + return self._clone(table_calculations=existing + table_calculations) + + def filter(self, filter: Union[DimensionFilter, TableCalculationFilter, CompositeFilter]) -> "Query": """ Add a filter to the query. @@ -456,7 +483,7 @@ def filter(self, filter: Union[DimensionFilter, CompositeFilter]) -> "Query": Returns a new Query with the filter added. Args: - filter: A DimensionFilter or CompositeFilter to apply + filter: A DimensionFilter, TableCalculationFilter or CompositeFilter to apply Returns: A new Query with the filter added @@ -470,13 +497,13 @@ def filter(self, filter: Union[DimensionFilter, CompositeFilter]) -> "Query": """ if self._filters is None: # First filter - if isinstance(filter, DimensionFilter): + if isinstance(filter, (DimensionFilter, TableCalculationFilter)): new_filters = CompositeFilter(filters=[filter]) else: new_filters = filter else: # Combine with existing filters using AND - if isinstance(filter, DimensionFilter): + if isinstance(filter, (DimensionFilter, TableCalculationFilter)): # Add to existing CompositeFilter's list new_filters = CompositeFilter( filters=list(self._filters.filters) + [filter], diff --git a/lightdash/table_calculations.py b/lightdash/table_calculations.py new file mode 100644 index 0000000..45c3017 --- /dev/null +++ b/lightdash/table_calculations.py @@ -0,0 +1,150 @@ +""" +Table calculations for Lightdash queries. + +A ``TableCalculation`` is a row-by-row expression evaluated on query results. +It serializes as a ``SqlTableCalculation`` (``{name, displayName, sql, type}``) +in the query payload, and can be referenced in filters via the comparison +operators below, mirroring the ``Dimension`` filter API. +""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING + +if TYPE_CHECKING: + from .filter import TableCalculationFilter + + +@dataclass +class TableCalculation: + """A Lightdash table calculation defined by a SQL expression. + + ``type`` is the result data type (``number``, ``string``, ``date``, + ``timestamp`` or ``boolean``). It defaults to ``number`` because filtering + is the common reason to reference a calculation, and numeric/comparison + operators only compile when the calc is typed - an untyped calc is treated + as ``string`` by the API and rejects ``>``, ``between``, etc. Set it + explicitly for non-numeric calculations. + """ + name: str + sql: str + display_name: Optional[str] = None + type: str = "number" + + def __hash__(self) -> int: + """Make TableCalculation hashable for use in sets and dict keys.""" + return hash((self.name, self.sql, self.display_name, self.type)) + + def __str__(self) -> str: + return f"TableCalculation({self.name})" + + def _repr_pretty_(self, p, cycle): + if cycle: + p.text("TableCalculation(...)") + else: + p.text(str(self)) + + @property + def field_id(self) -> str: + """Table calculations are referenced by their name (no model prefix).""" + return self.name + + def to_dict(self) -> Dict[str, str]: + """Serialize as a SqlTableCalculation for the query payload.""" + return { + "name": self.name, + "displayName": self.display_name or self.name, + "sql": self.sql, + "type": self.type, + } + + # ------------------------------------------------------------------------- + # Filter operator overloading (mirrors Dimension) + # ------------------------------------------------------------------------- + + def __eq__(self, other: Any) -> Union[bool, "TableCalculationFilter"]: # type: ignore[override] + """Create equals filter: calc == value or calc == [a, b]""" + if isinstance(other, TableCalculation): + # Allow normal dataclass equality checks + return ( + self.name == other.name + and self.sql == other.sql + and self.display_name == other.display_name + and self.type == other.type + ) + from .filter import TableCalculationFilter + values = other if isinstance(other, list) else [other] + return TableCalculationFilter(field=self, operator="equals", values=values) + + def __ne__(self, other: Any) -> Union[bool, "TableCalculationFilter"]: # type: ignore[override] + """Create not equals filter: calc != value""" + if isinstance(other, TableCalculation): + return not self.__eq__(other) + from .filter import TableCalculationFilter + values = other if isinstance(other, list) else [other] + return TableCalculationFilter(field=self, operator="notEquals", values=values) + + def __gt__(self, other: Any) -> "TableCalculationFilter": + """Create greater than filter: calc > value""" + from .filter import TableCalculationFilter + return TableCalculationFilter(field=self, operator="greaterThan", values=[other]) + + def __lt__(self, other: Any) -> "TableCalculationFilter": + """Create less than filter: calc < value""" + from .filter import TableCalculationFilter + return TableCalculationFilter(field=self, operator="lessThan", values=[other]) + + def __ge__(self, other: Any) -> "TableCalculationFilter": + """Create >= filter: calc >= value""" + from .filter import TableCalculationFilter + return TableCalculationFilter(field=self, operator="greaterThanOrEqual", values=[other]) + + def __le__(self, other: Any) -> "TableCalculationFilter": + """Create <= filter: calc <= value""" + from .filter import TableCalculationFilter + return TableCalculationFilter(field=self, operator="lessThanOrEqual", values=[other]) + + def in_(self, values: List[Any]) -> "TableCalculationFilter": + """Create 'in' filter: calc.in_([1, 2])""" + from .filter import TableCalculationFilter + return TableCalculationFilter(field=self, operator="equals", values=values) + + def not_in(self, values: List[Any]) -> "TableCalculationFilter": + """Create 'not in' filter: calc.not_in([1, 2])""" + from .filter import TableCalculationFilter + return TableCalculationFilter(field=self, operator="notEquals", values=values) + + def contains(self, value: str) -> "TableCalculationFilter": + """Create contains filter: calc.contains('substring')""" + from .filter import TableCalculationFilter + return TableCalculationFilter(field=self, operator="include", values=[value]) + + def starts_with(self, value: str) -> "TableCalculationFilter": + """Create starts with filter: calc.starts_with('prefix')""" + from .filter import TableCalculationFilter + return TableCalculationFilter(field=self, operator="startsWith", values=[value]) + + def ends_with(self, value: str) -> "TableCalculationFilter": + """Create ends with filter: calc.ends_with('suffix')""" + from .filter import TableCalculationFilter + return TableCalculationFilter(field=self, operator="endsWith", values=[value]) + + def is_null(self) -> "TableCalculationFilter": + """Create is null filter: calc.is_null()""" + from .filter import TableCalculationFilter + return TableCalculationFilter(field=self, operator="isNull", values=[]) + + def is_not_null(self) -> "TableCalculationFilter": + """Create is not null filter: calc.is_not_null()""" + from .filter import TableCalculationFilter + return TableCalculationFilter(field=self, operator="notNull", values=[]) + + def between(self, start: Any, end: Any) -> "TableCalculationFilter": + """Create between filter: calc.between(10, 100)""" + from .filter import TableCalculationFilter + return TableCalculationFilter(field=self, operator="inBetween", values=[start, end]) + + def not_between(self, start: Any, end: Any) -> "TableCalculationFilter": + """Create not between filter: calc.not_between(10, 100)""" + from .filter import TableCalculationFilter + return TableCalculationFilter(field=self, operator="notInBetween", values=[start, end]) diff --git a/tests/test_acceptance.py b/tests/test_acceptance.py index 0b04e94..308224b 100644 --- a/tests/test_acceptance.py +++ b/tests/test_acceptance.py @@ -347,6 +347,59 @@ def test_query_requires_client(client_params): ).to_records() +def test_query_table_calculation_filter(client): + """E2E: a filter on a table calculation is applied server-side (#21).""" + from lightdash import TableCalculation + + # Find a model that has at least one metric and one dimension. + model = metric = dim = None + for m in client.list_models(): + metrics = m.list_metrics() + dims = m.list_dimensions() + if metrics and dims: + model, metric, dim = m, metrics[0], dims[0] + break + if model is None: + pytest.skip("No model with both a metric and a dimension available") + + # A calc that simply copies the metric, so we can predict filter results. + calc = TableCalculation( + name="calc_copy", + sql="${%s.%s}" % (metric.model_name, metric.name), + ) + + base = model.query( + dimensions=[dim.field_id], + metrics=[metric.field_id], + table_calculations=[calc], + limit=200, + ).execute().to_records() + + values = sorted( + r["calc_copy"] for r in base + if isinstance(r.get("calc_copy"), (int, float)) + ) + if len(values) < 3: + pytest.skip("Not enough numeric rows to exercise a table-calc filter") + + threshold = values[len(values) // 2] # median + expected = sum(1 for v in values if v > threshold) + + filtered = ( + model.query() + .dimensions(dim.field_id) + .metrics(metric.field_id) + .table_calculations(calc) + .filter(calc > threshold) + .limit(200) + ).execute().to_records() + + # Every returned row satisfies the predicate, and the count matches the + # client-side expectation - proving the filter ran server-side. + assert all(r["calc_copy"] > threshold for r in filtered) + assert len(filtered) == expected + + def test_metric_field_id(): """Test that metrics generate correct field IDs.""" metric = Metric( diff --git a/tests/test_table_calculations.py b/tests/test_table_calculations.py new file mode 100644 index 0000000..1bf2de2 --- /dev/null +++ b/tests/test_table_calculations.py @@ -0,0 +1,252 @@ +""" +Tests for table calculations and table calculation filters (issue #21). + +Verifies that table calculations can be defined, added to queries, and used in +filters that serialize under ``filters.tableCalculations`` — the shape confirmed +against the Lightdash ``Filters`` / ``SqlTableCalculation`` types. +""" + +import pytest +from lightdash.dimensions import Dimension +from lightdash.filter import DimensionFilter, TableCalculationFilter, CompositeFilter +from lightdash.models import Model +from lightdash.table_calculations import TableCalculation + + +@pytest.fixture +def model(): + """Create a test model without client reference.""" + return Model( + name="test_model", + type="default", + database_name="test_db", + schema_name="test_schema", + ) + + +@pytest.fixture +def dimension(): + """Create a test dimension.""" + return Dimension( + name="country", + model_name="test_model", + label="Country", + description="Customer country", + ) + + +@pytest.fixture +def calc(): + """Create a test table calculation.""" + return TableCalculation( + name="profit_ratio", + sql="${test_model.profit} / ${test_model.revenue}", + display_name="Profit Ratio", + ) + + +class TestTableCalculation: + """Test the TableCalculation class.""" + + def test_field_id_is_name(self, calc): + """Table calculations are referenced by name, without a model prefix.""" + assert calc.field_id == "profit_ratio" + + def test_to_dict_matches_sql_table_calculation(self, calc): + """Serialization matches the SqlTableCalculation shape {name, displayName, sql, type}.""" + assert calc.to_dict() == { + "name": "profit_ratio", + "displayName": "Profit Ratio", + "sql": "${test_model.profit} / ${test_model.revenue}", + "type": "number", + } + + def test_to_dict_display_name_defaults_to_name(self): + """displayName falls back to name when not provided.""" + calc = TableCalculation(name="my_calc", sql="1 + 1") + assert calc.to_dict()["displayName"] == "my_calc" + + def test_type_defaults_to_number(self, calc): + """type defaults to number so comparison filters compile server-side.""" + assert calc.type == "number" + assert calc.to_dict()["type"] == "number" + + def test_type_can_be_overridden(self): + """Non-numeric calculations can declare their type.""" + calc = TableCalculation(name="label", sql="'x'", type="string") + assert calc.to_dict()["type"] == "string" + + def test_equality_between_calculations(self, calc): + """Comparing two TableCalculations returns bool, not a filter.""" + same = TableCalculation( + name="profit_ratio", + sql="${test_model.profit} / ${test_model.revenue}", + display_name="Profit Ratio", + ) + different = TableCalculation(name="other", sql="1") + assert calc == same + assert calc != different + + def test_hashable(self, calc): + """TableCalculations can be used in sets and as dict keys.""" + same = TableCalculation( + name="profit_ratio", + sql="${test_model.profit} / ${test_model.revenue}", + display_name="Profit Ratio", + ) + assert hash(calc) == hash(same) + assert calc in {calc} + + +class TestTableCalculationOperators: + """Test filter creation via operators on TableCalculation.""" + + def test_equals_operator(self, calc): + result = calc == 0.5 + assert isinstance(result, TableCalculationFilter) + assert result.operator == "equals" + assert result.values == [0.5] + + def test_not_equals_operator(self, calc): + result = calc != 0.5 + assert isinstance(result, TableCalculationFilter) + assert result.operator == "notEquals" + + def test_comparison_operators(self, calc): + assert (calc > 1).operator == "greaterThan" + assert (calc >= 1).operator == "greaterThanOrEqual" + assert (calc < 1).operator == "lessThan" + assert (calc <= 1).operator == "lessThanOrEqual" + + def test_helper_methods(self, calc): + assert calc.in_([1, 2]).operator == "equals" + assert calc.not_in([1, 2]).operator == "notEquals" + assert calc.contains("a").operator == "include" + assert calc.starts_with("a").operator == "startsWith" + assert calc.ends_with("a").operator == "endsWith" + assert calc.is_null().operator == "isNull" + assert calc.is_not_null().operator == "notNull" + assert calc.between(1, 2).operator == "inBetween" + assert calc.not_between(1, 2).operator == "notInBetween" + + +class TestTableCalculationFilter: + """Test the TableCalculationFilter class.""" + + def test_to_dict_with_calculation_object(self, calc): + result = (calc > 0.2).to_dict() + assert result == { + "target": {"fieldId": "profit_ratio"}, + "operator": "greaterThan", + "values": [0.2], + } + + def test_field_as_string(self): + """A table calculation can be referenced by name.""" + f = TableCalculationFilter( + field="profit_ratio", operator="greaterThan", values=[0.2] + ) + assert f.to_dict()["target"]["fieldId"] == "profit_ratio" + + def test_invalid_operator_raises(self, calc): + with pytest.raises(ValueError, match="Invalid operator"): + TableCalculationFilter(field=calc, operator="bogus", values=[1]) + + def test_invalid_field_raises(self): + with pytest.raises(TypeError, match="field must be a TableCalculation"): + TableCalculationFilter(field=123, operator="equals", values=[1]) + + def test_scalar_values_wrapped_in_list(self, calc): + f = TableCalculationFilter(field=calc, operator="equals", values=0.5) + assert f.values == [0.5] + + +class TestTableCalculationFilterSerialization: + """Test composite serialization under filters.tableCalculations.""" + + def test_calc_only_composite(self, calc): + composite = CompositeFilter(filters=[calc > 0.2]) + result = composite.to_dict() + assert result["dimensions"] == {"and": []} + rules = result["tableCalculations"]["and"] + assert len(rules) == 1 + assert rules[0]["target"]["fieldId"] == "profit_ratio" + + def test_mixed_composite(self, calc, dimension): + """Dimension and table calc filters serialize under separate keys.""" + composite = (dimension == "USA") & (calc > 0.2) + result = composite.to_dict() + dim_rules = result["dimensions"]["and"] + calc_rules = result["tableCalculations"]["and"] + assert len(dim_rules) == 1 + assert dim_rules[0]["target"]["fieldId"] == "test_model_country" + assert len(calc_rules) == 1 + assert calc_rules[0]["target"]["fieldId"] == "profit_ratio" + + def test_or_aggregation(self, calc): + composite = (calc > 0.8) | (calc < 0.2) + result = composite.to_dict() + assert len(result["tableCalculations"]["or"]) == 2 + + def test_no_calc_filters_omits_key(self, dimension): + """tableCalculations key is omitted when no calc filters exist.""" + composite = CompositeFilter(filters=[dimension == "USA"]) + assert "tableCalculations" not in composite.to_dict() + + +class TestQueryIntegration: + """Test table calculations in the query builder.""" + + def test_table_calculations_method_adds_calcs(self, model, calc): + query = model.query().table_calculations(calc) + assert query._table_calculations == (calc,) + + def test_table_calculations_accumulate(self, model, calc): + other = TableCalculation(name="other", sql="1") + query = model.query().table_calculations(calc).table_calculations(other) + assert query._table_calculations == (calc, other) + + def test_table_calculations_returns_new_query(self, model, calc): + query1 = model.query() + query2 = query1.table_calculations(calc) + assert query1 is not query2 + assert query1._table_calculations is None + + def test_query_kwarg(self, model, calc): + query = model.query(table_calculations=[calc]) + assert tuple(query._table_calculations) == (calc,) + + def test_payload_includes_table_calculations(self, model, calc): + query = model.query().table_calculations(calc) + payload = query._build_payload() + assert payload["tableCalculations"] == [calc.to_dict()] + + def test_payload_accepts_raw_dicts(self, model): + raw = {"name": "my_calc", "displayName": "My Calc", "sql": "1 + 1"} + query = model.query().table_calculations(raw) + payload = query._build_payload() + assert payload["tableCalculations"] == [raw] + + def test_filter_with_table_calculation(self, model, calc): + """Table calc filters flow through .filter() into the payload.""" + query = model.query().table_calculations(calc).filter(calc > 0.2) + payload = query._build_payload() + rules = payload["filters"]["tableCalculations"]["and"] + assert len(rules) == 1 + assert rules[0]["target"]["fieldId"] == "profit_ratio" + + def test_filter_combines_with_dimension_filters(self, model, calc, dimension): + query = ( + model.query() + .filter(dimension == "USA") + .filter(calc > 0.2) + ) + payload = query._build_payload() + assert len(payload["filters"]["dimensions"]["and"]) == 1 + assert len(payload["filters"]["tableCalculations"]["and"]) == 1 + + def test_filter_as_query_kwarg(self, model, calc): + """A single TableCalculationFilter can be passed as filters=...""" + query = model.query(filters=calc > 0.2) + payload = query._build_payload() + assert len(payload["filters"]["tableCalculations"]["and"]) == 1