diff --git a/pyproject.toml b/pyproject.toml index 641483d..60d5822 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "sap-cloud-sdk" -version = "0.27.1" +version = "0.28.0" description = "SAP Cloud SDK for Python" readme = "README.md" license = "Apache-2.0" diff --git a/src/sap_cloud_sdk/core/odata/__init__.py b/src/sap_cloud_sdk/core/odata/__init__.py new file mode 100644 index 0000000..3694191 --- /dev/null +++ b/src/sap_cloud_sdk/core/odata/__init__.py @@ -0,0 +1,48 @@ +"""Shared OData v4 abstractions for the SAP Cloud SDK.""" + +from sap_cloud_sdk.core.odata._async_transport import AsyncODataHttpTransport +from sap_cloud_sdk.core.odata._factory import odata_transport_from_destination +from sap_cloud_sdk.core.odata._filter import FilterExpression +from sap_cloud_sdk.core.odata._models import ODataEntity +from sap_cloud_sdk.core.odata._pagination import ODataPageIterator +from sap_cloud_sdk.core.odata._query import OrderDirection, StructuredQuery +from sap_cloud_sdk.core.odata._request_builders import ( + CreateRequestBuilder, + DeleteRequestBuilder, + GetAllRequestBuilder, + GetByKeyRequestBuilder, + UpdateRequestBuilder, +) +from sap_cloud_sdk.core.odata._transport import ODataHttpTransport +from sap_cloud_sdk.core.odata.exceptions import ( + ODataAuthError, + ODataConnectionError, + ODataCsrfError, + ODataDeserializationError, + ODataError, + ODataNotFoundError, + ODataRequestError, +) + +__all__ = [ + "AsyncODataHttpTransport", + "CreateRequestBuilder", + "DeleteRequestBuilder", + "FilterExpression", + "GetAllRequestBuilder", + "GetByKeyRequestBuilder", + "ODataAuthError", + "ODataConnectionError", + "ODataCsrfError", + "ODataDeserializationError", + "ODataEntity", + "ODataError", + "ODataHttpTransport", + "ODataNotFoundError", + "ODataPageIterator", + "ODataRequestError", + "OrderDirection", + "StructuredQuery", + "UpdateRequestBuilder", + "odata_transport_from_destination", +] diff --git a/src/sap_cloud_sdk/core/odata/_async_transport.py b/src/sap_cloud_sdk/core/odata/_async_transport.py new file mode 100644 index 0000000..c4e0670 --- /dev/null +++ b/src/sap_cloud_sdk/core/odata/_async_transport.py @@ -0,0 +1,199 @@ +"""Asynchronous HTTP transport for OData v4 services.""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Any + +import httpx + +from sap_cloud_sdk.core.odata._constants import ( + CSRF_FETCH_TIMEOUT, + CSRF_FETCH_VALUE, + CSRF_HEADER, + DEFAULT_HEADERS, + MUTATING_METHODS, + REQUEST_TIMEOUT, +) +from sap_cloud_sdk.core.odata.exceptions import ( + ODataAuthError, + ODataConnectionError, + ODataCsrfError, + ODataNotFoundError, + ODataRequestError, +) + +logger = logging.getLogger(__name__) + + +class AsyncODataHttpTransport: + """Asynchronous HTTP transport for OData v4 services. + + Mirrors :class:`~sap_cloud_sdk.core.odata._transport.ODataHttpTransport` + but uses ``httpx.AsyncClient``. Use as an async context manager:: + + async with AsyncODataHttpTransport(base_url, client) as t: + data = await t.request("GET", "BusinessPartnerSet") + + Args: + base_url: Root URL of the OData service. + client: Pre-configured ``httpx.AsyncClient``. + csrf_enabled: Whether to fetch and attach CSRF tokens on mutating + requests. + """ + + def __init__( + self, + base_url: str, + client: httpx.AsyncClient, + csrf_enabled: bool = True, + ) -> None: + self._base_url = base_url.rstrip("/") + self._client = client + self._csrf_enabled = csrf_enabled + self._csrf_token: str | None = None + self._csrf_lock = asyncio.Lock() + + async def __aenter__(self) -> "AsyncODataHttpTransport": + return self + + async def __aexit__(self, *args: Any) -> None: + await self._client.aclose() + + async def request( + self, + method: str, + path: str, + *, + params: dict[str, Any] | None = None, + json: Any | None = None, + headers: dict[str, str] | None = None, + ) -> dict[str, Any]: + """Execute an OData request and return the parsed JSON body. + + CSRF tokens are fetched and attached automatically for mutating methods + (POST, PUT, PATCH, DELETE) when ``csrf_enabled`` is ``True``. On a + 403 response the cached token is invalidated and the request is retried + once with a fresh token. + + Args: + method: HTTP method (``"GET"``, ``"POST"``, ``"PATCH"``, etc.). + path: Entity path relative to the service base URL. + params: OData query parameters. + json: Request body serialised as JSON. + headers: Extra headers merged on top of the defaults. + + Returns: + Parsed JSON response body, or ``{}`` for 204 / empty responses. + """ + extra = dict(headers or {}) + + if method.upper() in MUTATING_METHODS and self._csrf_enabled: + extra[CSRF_HEADER] = await self._get_csrf_token() + try: + return await self._execute( + method, path, params=params, json=json, extra_headers=extra + ) + except ODataAuthError as exc: + if exc.status_code == 403: + await self._invalidate_csrf_token() + extra[CSRF_HEADER] = await self._get_csrf_token() + return await self._execute( + method, path, params=params, json=json, extra_headers=extra + ) + raise + + return await self._execute( + method, path, params=params, json=json, extra_headers=extra + ) + + def absolute_url(self, path: str) -> str: + return self._base_url + "/" + path.lstrip("/") + + # ------------------------------------------------------------------ + # Internal + # ------------------------------------------------------------------ + + async def _get_csrf_token(self) -> str: + async with self._csrf_lock: + if self._csrf_token is not None: + return self._csrf_token + + token = await self._fetch_csrf_token() + async with self._csrf_lock: + if self._csrf_token is None: + self._csrf_token = token + return self._csrf_token # type: ignore[return-value] + + async def _invalidate_csrf_token(self) -> None: + async with self._csrf_lock: + self._csrf_token = None + + async def _fetch_csrf_token(self) -> str: + url = self._base_url + "/" + try: + resp = await self._client.get( + url, + headers={CSRF_HEADER: CSRF_FETCH_VALUE}, + timeout=CSRF_FETCH_TIMEOUT, + ) + except httpx.RequestError as exc: + raise ODataCsrfError(f"Async CSRF fetch failed: {exc}") from exc + + token = resp.headers.get(CSRF_HEADER, "") + if not token: + raise ODataCsrfError( + f"Service did not return a CSRF token (HTTP {resp.status_code})" + ) + return token + + async def _execute( + self, + method: str, + path: str, + *, + params: dict[str, Any] | None = None, + json: Any | None = None, + extra_headers: dict[str, str] | None = None, + ) -> dict[str, Any]: + url = self.absolute_url(path) + req_headers = {**DEFAULT_HEADERS, **(extra_headers or {})} + + logger.debug("%s %s params=%s", method, url, params) + try: + resp = await self._client.request( + method=method, + url=url, + headers=req_headers, + params=params, + json=json, + timeout=REQUEST_TIMEOUT, + ) + except httpx.RequestError as exc: + raise ODataConnectionError(f"Request failed: {exc}") from exc + + self._raise_for_status(resp) + + if resp.status_code == 204 or not resp.content: + return {} + return resp.json() + + def _raise_for_status(self, response: httpx.Response) -> None: + if response.status_code == 404: + raise ODataNotFoundError(_HttpxResponseAdapter(response)) + if response.status_code in (401, 403): + raise ODataAuthError(_HttpxResponseAdapter(response)) + if not (200 <= response.status_code < 300): + raise ODataRequestError(_HttpxResponseAdapter(response)) + + +class _HttpxResponseAdapter: + """Minimal adapter so httpx.Response can be passed to ODataRequestError.""" + + def __init__(self, response: httpx.Response) -> None: + self.status_code = response.status_code + self._response = response + + def json(self) -> Any: + return self._response.json() diff --git a/src/sap_cloud_sdk/core/odata/_constants.py b/src/sap_cloud_sdk/core/odata/_constants.py new file mode 100644 index 0000000..874d106 --- /dev/null +++ b/src/sap_cloud_sdk/core/odata/_constants.py @@ -0,0 +1,52 @@ +"""Shared constants for the OData v4 HTTP layer.""" + +from __future__ import annotations + +# --------------------------------------------------------------------------- +# CSRF +# --------------------------------------------------------------------------- + +CSRF_HEADER = "X-CSRF-Token" +CSRF_FETCH_VALUE = "Fetch" +CSRF_FETCH_TIMEOUT = 10 + +# --------------------------------------------------------------------------- +# HTTP +# --------------------------------------------------------------------------- + +REQUEST_TIMEOUT = 30 + +MUTATING_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"}) + +DEFAULT_HEADERS: dict[str, str] = { + "Accept": "application/json", + "Content-Type": "application/json", +} + +# HTTP method literals +GET = "GET" +POST = "POST" +PUT = "PUT" +PATCH = "PATCH" +DELETE = "DELETE" + +# Standard conditional-request header +IF_MATCH_HEADER = "If-Match" + +# --------------------------------------------------------------------------- +# OData system query options +# --------------------------------------------------------------------------- + +QUERY_SELECT = "$select" +QUERY_FILTER = "$filter" +QUERY_ORDERBY = "$orderby" +QUERY_TOP = "$top" +QUERY_SKIP = "$skip" +QUERY_EXPAND = "$expand" + +# --------------------------------------------------------------------------- +# OData response envelope keys +# --------------------------------------------------------------------------- + +RESPONSE_VALUE = "value" +RESPONSE_NEXT_LINK = "@odata.nextLink" diff --git a/src/sap_cloud_sdk/core/odata/_csrf.py b/src/sap_cloud_sdk/core/odata/_csrf.py new file mode 100644 index 0000000..ff99705 --- /dev/null +++ b/src/sap_cloud_sdk/core/odata/_csrf.py @@ -0,0 +1,72 @@ +"""CSRF token fetch-and-cache for OData v4 mutating requests.""" + +from __future__ import annotations + +import threading +from typing import TYPE_CHECKING + +import requests as _requests + +from sap_cloud_sdk.core.odata._constants import ( + CSRF_FETCH_TIMEOUT, + CSRF_FETCH_VALUE, + CSRF_HEADER, +) +from sap_cloud_sdk.core.odata.exceptions import ODataCsrfError + +if TYPE_CHECKING: + from ._transport import ODataHttpTransport + + +class CsrfTokenProvider: + """Fetch and cache a CSRF token for one OData service root. + + The token is fetched lazily on the first mutating request and cached + until it is invalidated (typically after a 403 response). + + Thread-safe: internal state is protected by a :class:`threading.Lock`. + + Args: + transport: The owning :class:`ODataHttpTransport` whose session and + base URL are used to perform the CSRF-fetch GET. + """ + + def __init__(self, transport: "ODataHttpTransport") -> None: + self._transport = transport + self._token: str | None = None + self._lock = threading.Lock() + + def get(self) -> str: + """Return the cached CSRF token, fetching from the service if needed.""" + with self._lock: + if self._token is not None: + return self._token + + token = self._fetch() + with self._lock: + if self._token is None: + self._token = token + return self._token + + def invalidate(self) -> None: + """Discard the cached token so the next call re-fetches.""" + with self._lock: + self._token = None + + def _fetch(self) -> str: + url = self._transport._base_url + "/" + try: + resp = self._transport._session.get( + url, + headers={CSRF_HEADER: CSRF_FETCH_VALUE}, + timeout=CSRF_FETCH_TIMEOUT, + ) + except _requests.RequestException as exc: + raise ODataCsrfError(f"CSRF fetch failed: {exc}") from exc + + token = resp.headers.get(CSRF_HEADER, "") + if not token: + raise ODataCsrfError( + f"Service did not return a CSRF token (HTTP {resp.status_code})" + ) + return token diff --git a/src/sap_cloud_sdk/core/odata/_factory.py b/src/sap_cloud_sdk/core/odata/_factory.py new file mode 100644 index 0000000..3fee85a --- /dev/null +++ b/src/sap_cloud_sdk/core/odata/_factory.py @@ -0,0 +1,80 @@ +"""Factory for building OData transports from BTP Destinations.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import requests + +from sap_cloud_sdk.core.odata._transport import ODataHttpTransport + +if TYPE_CHECKING: + from sap_cloud_sdk.destination._models import Destination + + +def odata_transport_from_destination( + destination: "Destination", + *, + odata_path: str = "", + csrf_enabled: bool = True, +) -> ODataHttpTransport: + """Build an :class:`ODataHttpTransport` from a resolved BTP Destination. + + The destination's auth tokens and ERP headers are pre-baked into the + underlying ``requests.Session`` exactly as ``DestinationHttpClient`` does, + so the transport inherits whatever authentication the destination carries + (Bearer, Basic, mTLS, …). + + Args: + destination: A fully-resolved ``Destination`` object (i.e. returned by + ``DestinationClient.get_destination()`` so ``auth_tokens`` are + populated). + odata_path: Optional sub-path appended to the destination URL to form + the OData service root (e.g. ``"sap/opu/odata4/svc/"``). Useful + when the destination URL points to the host root rather than the + service root directly. + csrf_enabled: Whether to fetch and attach CSRF tokens on mutating + requests. Defaults to ``True``. + + Returns: + :class:`ODataHttpTransport` ready to pass into any request builder. + + Raises: + ValueError: If the destination has no URL or is not an HTTP destination. + + Example:: + + from sap_cloud_sdk.destination import create_client + from sap_cloud_sdk.core.odata._factory import odata_transport_from_destination + from sap_cloud_sdk.core.odata._request_builders import GetAllRequestBuilder + + dest_client = create_client() + destination = dest_client.get_destination("S4HANA_OData") + + transport = odata_transport_from_destination(destination) + results = GetAllRequestBuilder(transport, BusinessPartner).top(10).execute() + """ + from sap_cloud_sdk.destination._models import DestinationType + + if destination.type != DestinationType.HTTP: + raise ValueError( + f"odata_transport_from_destination only supports HTTP destinations, " + f"got: {destination.type}" + ) + if not destination.url: + raise ValueError( + f"Destination '{destination.name}' has no URL — cannot build OData transport" + ) + + base_url = destination.url.rstrip("/") + if odata_path: + base_url = base_url + "/" + odata_path.strip("/") + + session = requests.Session() + session.headers.update(destination.get_headers()) + + return ODataHttpTransport( + base_url=base_url, + session=session, + csrf_enabled=csrf_enabled, + ) diff --git a/src/sap_cloud_sdk/core/odata/_filter.py b/src/sap_cloud_sdk/core/odata/_filter.py new file mode 100644 index 0000000..b3e28e1 --- /dev/null +++ b/src/sap_cloud_sdk/core/odata/_filter.py @@ -0,0 +1,97 @@ +"""Filter expression DSL for OData v4 $filter query options.""" + +from __future__ import annotations + +from typing import Any + + +def _format_value(value: Any) -> str: + """Serialise a Python value to an OData v4 literal.""" + if isinstance(value, bool): + return "true" if value else "false" + if isinstance(value, str): + return "'" + value.replace("'", "''") + "'" + return str(value) + + +class FilterExpression: + """Composable OData v4 ``$filter`` expression. + + Build expressions via :meth:`field` and combine them with :meth:`and_`, + :meth:`or_`, and :meth:`not_`:: + + f = ( + FilterExpression.field("Price").gt(100) + .and_(FilterExpression.field("Category").eq("Books")) + ) + str(f) # "(Price gt 100) and (Category eq 'Books')" + """ + + __slots__ = ("_expr",) + + def __init__(self, expr: str) -> None: + self._expr = expr + + @staticmethod + def field(name: str) -> "_FieldRef": + """Start a comparison expression for *name*.""" + return _FieldRef(name) + + def and_(self, other: "FilterExpression") -> "FilterExpression": + return FilterExpression(f"({self._expr}) and ({other._expr})") + + def or_(self, other: "FilterExpression") -> "FilterExpression": + return FilterExpression(f"({self._expr}) or ({other._expr})") + + def not_(self) -> "FilterExpression": + return FilterExpression(f"not ({self._expr})") + + def __str__(self) -> str: + return self._expr + + def __repr__(self) -> str: + return f"FilterExpression({self._expr!r})" + + def __eq__(self, other: object) -> bool: + if isinstance(other, FilterExpression): + return self._expr == other._expr + return NotImplemented + + def __hash__(self) -> int: + return hash(self._expr) + + +class _FieldRef: + """Intermediate object: a field name awaiting a comparison operator.""" + + __slots__ = ("_name",) + + def __init__(self, name: str) -> None: + self._name = name + + def eq(self, value: Any) -> FilterExpression: + return FilterExpression(f"{self._name} eq {_format_value(value)}") + + def ne(self, value: Any) -> FilterExpression: + return FilterExpression(f"{self._name} ne {_format_value(value)}") + + def lt(self, value: Any) -> FilterExpression: + return FilterExpression(f"{self._name} lt {_format_value(value)}") + + def le(self, value: Any) -> FilterExpression: + return FilterExpression(f"{self._name} le {_format_value(value)}") + + def gt(self, value: Any) -> FilterExpression: + return FilterExpression(f"{self._name} gt {_format_value(value)}") + + def ge(self, value: Any) -> FilterExpression: + return FilterExpression(f"{self._name} ge {_format_value(value)}") + + def contains(self, value: str) -> FilterExpression: + return FilterExpression(f"contains({self._name}, {_format_value(value)})") + + def starts_with(self, value: str) -> FilterExpression: + return FilterExpression(f"startswith({self._name}, {_format_value(value)})") + + def ends_with(self, value: str) -> FilterExpression: + return FilterExpression(f"endswith({self._name}, {_format_value(value)})") diff --git a/src/sap_cloud_sdk/core/odata/_models.py b/src/sap_cloud_sdk/core/odata/_models.py new file mode 100644 index 0000000..4a190fd --- /dev/null +++ b/src/sap_cloud_sdk/core/odata/_models.py @@ -0,0 +1,41 @@ +"""Protocol-level types for OData v4 entities.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, ClassVar + + +@dataclass +class ODataEntity: + """Base class for OData v4 entity types. + + Generated and hand-written entity dataclasses may inherit from this to + allow the transport layer's generic serialiser to reflect on key fields + and entity-set metadata without inspecting the concrete type directly. + + Subclasses declare their metadata via ClassVar annotations:: + + @dataclass + class BusinessPartner(ODataEntity): + _entity_set: ClassVar[str] = "BusinessPartnerSet" + _key_fields: ClassVar[list[str]] = ["BusinessPartnerID"] + + BusinessPartnerID: str = "" + DisplayName: str = "" + """ + + _entity_set: ClassVar[str] = "" + _key_fields: ClassVar[list[str]] = [] + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-serialisable dict of this entity's fields.""" + result: dict[str, Any] = {} + for f in self.__dataclass_fields__: # type: ignore[attr-defined] + if not f.startswith("_"): + result[f] = getattr(self, f) + return result + + def key_dict(self) -> dict[str, Any]: + """Return only the key fields as a dict.""" + return {k: getattr(self, k) for k in self._key_fields} diff --git a/src/sap_cloud_sdk/core/odata/_pagination.py b/src/sap_cloud_sdk/core/odata/_pagination.py new file mode 100644 index 0000000..99cdc25 --- /dev/null +++ b/src/sap_cloud_sdk/core/odata/_pagination.py @@ -0,0 +1,42 @@ +"""Server-driven pagination via @odata.nextLink.""" + +from __future__ import annotations + +from typing import Any, Callable, Generic, Iterator, TypeVar + +from sap_cloud_sdk.core.odata._response import deserialize_collection, next_link + +T = TypeVar("T") + + +class ODataPageIterator(Generic[T]): + """Lazily yields pages of entities by following ``@odata.nextLink``. + + Args: + fetch_page: Callable that takes an absolute URL and returns a raw + JSON dict (the full OData collection response). + entity_type: Dataclass to deserialize each item into. + first_url: The initial request URL (already including query params). + """ + + def __init__( + self, + fetch_page: Callable[[str], dict[str, Any]], + entity_type: type[T], + first_url: str, + ) -> None: + self._fetch_page = fetch_page + self._entity_type = entity_type + self._first_url = first_url + + def __iter__(self) -> Iterator[list[T]]: + url: str | None = self._first_url + while url is not None: + data = self._fetch_page(url) + yield deserialize_collection(data, self._entity_type) + url = next_link(data) + + def entities(self) -> Iterator[T]: + """Yield individual entities across all pages.""" + for page in self: + yield from page diff --git a/src/sap_cloud_sdk/core/odata/_query.py b/src/sap_cloud_sdk/core/odata/_query.py new file mode 100644 index 0000000..46bce58 --- /dev/null +++ b/src/sap_cloud_sdk/core/odata/_query.py @@ -0,0 +1,145 @@ +"""Immutable OData v4 query parameter builder.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import TYPE_CHECKING + +from sap_cloud_sdk.core.odata._constants import ( + QUERY_EXPAND, + QUERY_FILTER, + QUERY_ORDERBY, + QUERY_SELECT, + QUERY_SKIP, + QUERY_TOP, +) + +if TYPE_CHECKING: + from ._filter import FilterExpression + + +class OrderDirection(Enum): + ASC = "asc" + DESC = "desc" + + +@dataclass(frozen=True) +class StructuredQuery: + """Immutable OData v4 query parameter builder. + + Every mutating method returns a new instance — safe to share a base query + across multiple requests:: + + base = StructuredQuery().select("ID", "Name").top(50) + page1 = base.skip(0) + page2 = base.skip(50) + page1.to_params() + # {"$select": "ID,Name", "$top": "50", "$skip": "0"} + """ + + _select: tuple[str, ...] = field(default=(), compare=True) + _filter: "FilterExpression | None" = field(default=None, compare=True) + _orderby: tuple[tuple[str, OrderDirection], ...] = field(default=(), compare=True) + _top: int | None = field(default=None, compare=True) + _skip: int | None = field(default=None, compare=True) + _expand: tuple[str, ...] = field(default=(), compare=True) + _custom: tuple[tuple[str, str], ...] = field(default=(), compare=True) + + def select(self, *fields: str) -> "StructuredQuery": + return StructuredQuery( + _select=tuple(fields), + _filter=self._filter, + _orderby=self._orderby, + _top=self._top, + _skip=self._skip, + _expand=self._expand, + _custom=self._custom, + ) + + def filter(self, expression: "FilterExpression") -> "StructuredQuery": + return StructuredQuery( + _select=self._select, + _filter=expression, + _orderby=self._orderby, + _top=self._top, + _skip=self._skip, + _expand=self._expand, + _custom=self._custom, + ) + + def order_by( + self, field_name: str, direction: OrderDirection = OrderDirection.ASC + ) -> "StructuredQuery": + return StructuredQuery( + _select=self._select, + _filter=self._filter, + _orderby=self._orderby + ((field_name, direction),), + _top=self._top, + _skip=self._skip, + _expand=self._expand, + _custom=self._custom, + ) + + def top(self, n: int) -> "StructuredQuery": + return StructuredQuery( + _select=self._select, + _filter=self._filter, + _orderby=self._orderby, + _top=n, + _skip=self._skip, + _expand=self._expand, + _custom=self._custom, + ) + + def skip(self, n: int) -> "StructuredQuery": + return StructuredQuery( + _select=self._select, + _filter=self._filter, + _orderby=self._orderby, + _top=self._top, + _skip=n, + _expand=self._expand, + _custom=self._custom, + ) + + def expand(self, *nav_properties: str) -> "StructuredQuery": + return StructuredQuery( + _select=self._select, + _filter=self._filter, + _orderby=self._orderby, + _top=self._top, + _skip=self._skip, + _expand=tuple(nav_properties), + _custom=self._custom, + ) + + def custom(self, key: str, value: str) -> "StructuredQuery": + filtered = tuple((k, v) for k, v in self._custom if k != key) + return StructuredQuery( + _select=self._select, + _filter=self._filter, + _orderby=self._orderby, + _top=self._top, + _skip=self._skip, + _expand=self._expand, + _custom=filtered + ((key, value),), + ) + + def to_params(self) -> dict[str, str]: + params: dict[str, str] = {} + if self._select: + params[QUERY_SELECT] = ",".join(self._select) + if self._filter is not None: + params[QUERY_FILTER] = str(self._filter) + if self._orderby: + params[QUERY_ORDERBY] = ",".join(f"{f} {d.value}" for f, d in self._orderby) + if self._top is not None: + params[QUERY_TOP] = str(self._top) + if self._skip is not None: + params[QUERY_SKIP] = str(self._skip) + if self._expand: + params[QUERY_EXPAND] = ",".join(self._expand) + for k, v in self._custom: + params[k] = v + return params diff --git a/src/sap_cloud_sdk/core/odata/_request_builders.py b/src/sap_cloud_sdk/core/odata/_request_builders.py new file mode 100644 index 0000000..4d4a1ed --- /dev/null +++ b/src/sap_cloud_sdk/core/odata/_request_builders.py @@ -0,0 +1,248 @@ +"""Generic CRUD request builder classes for OData v4.""" + +from __future__ import annotations + +import dataclasses +from typing import Any, Generic, Iterator, TypeVar, TYPE_CHECKING, cast +from urllib.parse import urlencode + +from sap_cloud_sdk.core.odata._constants import ( + DELETE, + GET, + IF_MATCH_HEADER, + PATCH, + POST, + PUT, +) +from sap_cloud_sdk.core.odata._filter import _format_value +from sap_cloud_sdk.core.odata._query import OrderDirection, StructuredQuery +from sap_cloud_sdk.core.odata._response import ( + deserialize_collection, + deserialize_single, +) +from sap_cloud_sdk.core.odata._pagination import ODataPageIterator +from sap_cloud_sdk.core.telemetry import Module, Operation, record_metrics + +if TYPE_CHECKING: + from sap_cloud_sdk.core.odata._filter import FilterExpression + from sap_cloud_sdk.core.odata._transport import ODataHttpTransport + +T = TypeVar("T") + + +def _entity_set_path(entity_type: type) -> str: + """Return the OData entity-set path for *entity_type*. + + Reads ``entity_type._entity_set`` when present; otherwise defaults to + the class name (common for generated types). + """ + return getattr(entity_type, "_entity_set", None) or entity_type.__name__ + + +def _build_key_segment(key: dict[str, Any]) -> str: + """Serialise *key* dict to an OData key segment, e.g. ``(ID='x',Ver=1)``.""" + if len(key) == 1: + return f"({_format_value(next(iter(key.values())))})" + parts = ",".join(f"{k}={_format_value(v)}" for k, v in key.items()) + return f"({parts})" + + +class GetAllRequestBuilder(Generic[T]): + """Fluent builder for OData collection (GetAll) requests. + + Example:: + + results = ( + GetAllRequestBuilder(transport, BusinessPartner) + .select("BusinessPartnerID", "DisplayName") + .filter(FilterExpression.field("DisplayName").contains("Acme")) + .top(50) + .execute() + ) + """ + + def __init__(self, transport: "ODataHttpTransport", entity_type: type[T]) -> None: + self._transport = transport + self._entity_type = entity_type + self._query = StructuredQuery() + + def select(self, *fields: str) -> "GetAllRequestBuilder[T]": + self._query = self._query.select(*fields) + return self + + def filter(self, expression: "FilterExpression") -> "GetAllRequestBuilder[T]": + self._query = self._query.filter(expression) + return self + + def order_by( + self, + field_name: str, + direction: OrderDirection = OrderDirection.ASC, + ) -> "GetAllRequestBuilder[T]": + self._query = self._query.order_by(field_name, direction) + return self + + def top(self, n: int) -> "GetAllRequestBuilder[T]": + self._query = self._query.top(n) + return self + + def skip(self, n: int) -> "GetAllRequestBuilder[T]": + self._query = self._query.skip(n) + return self + + def expand(self, *nav_properties: str) -> "GetAllRequestBuilder[T]": + self._query = self._query.expand(*nav_properties) + return self + + @record_metrics(Module.ODATA, Operation.ODATA_GET_ALL) + def execute(self) -> list[T]: + """Execute the request and return all matching entities.""" + path = _entity_set_path(self._entity_type) + data = self._transport.request(GET, path, params=self._query.to_params()) + return deserialize_collection(data, self._entity_type) + + def iterate_pages(self) -> Iterator[list[T]]: + """Yield pages using server-driven pagination (``@odata.nextLink``).""" + path = _entity_set_path(self._entity_type) + first_url = self._transport.absolute_url(path) + params = self._query.to_params() + if params: + first_url += "?" + urlencode(params) + + iterator = ODataPageIterator( + fetch_page=lambda url: self._transport.request( + GET, _strip_base(url, self._transport._base_url) + ), + entity_type=self._entity_type, + first_url=first_url, + ) + yield from iterator + + def iterate_entities(self) -> Iterator[T]: + """Yield individual entities across all pages.""" + for page in self.iterate_pages(): + yield from page + + +def _strip_base(url: str, base_url: str) -> str: + """Strip *base_url* prefix from *url* to get a relative path.""" + prefix = base_url + "/" + if url.startswith(prefix): + return url[len(prefix) :] + return url + + +class GetByKeyRequestBuilder(Generic[T]): + """Fluent builder for a single-entity (GetByKey) request.""" + + def __init__( + self, + transport: "ODataHttpTransport", + entity_type: type[T], + key: dict[str, Any], + ) -> None: + self._transport = transport + self._entity_type = entity_type + self._key = key + self._query = StructuredQuery() + + def select(self, *fields: str) -> "GetByKeyRequestBuilder[T]": + self._query = self._query.select(*fields) + return self + + def expand(self, *nav_properties: str) -> "GetByKeyRequestBuilder[T]": + self._query = self._query.expand(*nav_properties) + return self + + @record_metrics(Module.ODATA, Operation.ODATA_GET_BY_KEY) + def execute(self) -> T: + """Fetch the entity, raising :exc:`ODataNotFoundError` if absent.""" + path = _entity_set_path(self._entity_type) + _build_key_segment(self._key) + data = self._transport.request(GET, path, params=self._query.to_params()) + return deserialize_single(data, self._entity_type) + + +class CreateRequestBuilder(Generic[T]): + """Builder for OData entity creation (POST).""" + + def __init__(self, transport: "ODataHttpTransport", entity: T) -> None: + self._transport = transport + self._entity = entity + + @record_metrics(Module.ODATA, Operation.ODATA_CREATE) + def execute(self) -> T: + """Create the entity and return the server response as the same type.""" + entity_type = type(self._entity) + path = _entity_set_path(entity_type) + e: Any = cast(Any, self._entity) + body = e.to_dict() if hasattr(e, "to_dict") else dataclasses.asdict(e) + data = self._transport.request(POST, path, json=body) + return deserialize_single(data, entity_type) + + +class UpdateRequestBuilder(Generic[T]): + """Builder for OData entity update (PATCH by default, PUT when ``.replace()`` called).""" + + def __init__( + self, + transport: "ODataHttpTransport", + entity: T, + etag: str | None = None, + ) -> None: + self._transport = transport + self._entity = entity + self._use_put = False + self._etag = etag + + def replace(self) -> "UpdateRequestBuilder[T]": + """Switch from PATCH (default) to PUT (full replacement).""" + self._use_put = True + return self + + @record_metrics(Module.ODATA, Operation.ODATA_UPDATE) + def execute(self) -> T: + """Send the update and return the server response.""" + entity_type = type(self._entity) + key_fields: list[str] = getattr(entity_type, "_key_fields", []) + if not key_fields: + raise ValueError( + f"{entity_type.__name__} does not define _key_fields; " + "cannot build a key path for update" + ) + key = {k: getattr(self._entity, k) for k in key_fields} + path = _entity_set_path(entity_type) + _build_key_segment(key) + e: Any = cast(Any, self._entity) + body = e.to_dict() if hasattr(e, "to_dict") else dataclasses.asdict(e) + method = PUT if self._use_put else PATCH + extra: dict[str, str] = {} + if self._etag is not None: + extra[IF_MATCH_HEADER] = self._etag + data = self._transport.request(method, path, json=body, headers=extra or None) + if not data: + return self._entity + return deserialize_single(data, entity_type) + + +class DeleteRequestBuilder(Generic[T]): + """Builder for OData entity deletion (DELETE).""" + + def __init__( + self, + transport: "ODataHttpTransport", + entity_type: type[T], + key: dict[str, Any], + etag: str | None = None, + ) -> None: + self._transport = transport + self._entity_type = entity_type + self._key = key + self._etag = etag + + @record_metrics(Module.ODATA, Operation.ODATA_DELETE) + def execute(self) -> None: + """Delete the entity.""" + path = _entity_set_path(self._entity_type) + _build_key_segment(self._key) + extra: dict[str, str] = {} + if self._etag is not None: + extra[IF_MATCH_HEADER] = self._etag + self._transport.request(DELETE, path, headers=extra or None) diff --git a/src/sap_cloud_sdk/core/odata/_response.py b/src/sap_cloud_sdk/core/odata/_response.py new file mode 100644 index 0000000..ba7f63a --- /dev/null +++ b/src/sap_cloud_sdk/core/odata/_response.py @@ -0,0 +1,63 @@ +"""OData v4 response parsing and entity deserialisation.""" + +from __future__ import annotations + +import dataclasses +from typing import Any, TypeVar + +from sap_cloud_sdk.core.odata._constants import RESPONSE_NEXT_LINK, RESPONSE_VALUE +from sap_cloud_sdk.core.odata.exceptions import ODataDeserializationError + +T = TypeVar("T") + + +def deserialize_single(data: dict[str, Any], entity_type: type[T]) -> T: + """Deserialise a single OData entity dict into *entity_type*. + + Accepts both a raw entity dict and an OData response envelope + (``{"value": {...}}``). Unknown fields in *data* are silently ignored so + that server-side ``@odata.*`` annotations do not break deserialisation. + """ + if not dataclasses.is_dataclass(entity_type): + raise ODataDeserializationError( + f"{entity_type!r} is not a dataclass — cannot deserialize" + ) + try: + payload = ( + data.get(RESPONSE_VALUE, data) + if isinstance(data.get(RESPONSE_VALUE), dict) + else data + ) + known = {f.name for f in dataclasses.fields(entity_type)} # type: ignore[arg-type] + kwargs = {k: v for k, v in payload.items() if k in known} + return entity_type(**kwargs) # type: ignore[call-arg] + except Exception as exc: + raise ODataDeserializationError( + f"Failed to deserialize {entity_type.__name__}: {exc}" + ) from exc + + +def deserialize_collection(data: dict[str, Any], entity_type: type[T]) -> list[T]: + """Deserialise an OData collection response into a list of *entity_type*. + + Expects ``{"value": [...]}`` envelope. Returns an empty list when the + ``value`` key is absent. + """ + if not dataclasses.is_dataclass(entity_type): + raise ODataDeserializationError( + f"{entity_type!r} is not a dataclass — cannot deserialize" + ) + try: + items: list[dict[str, Any]] = data.get(RESPONSE_VALUE, []) + return [deserialize_single(item, entity_type) for item in items] + except ODataDeserializationError: + raise + except Exception as exc: + raise ODataDeserializationError( + f"Failed to deserialize collection of {entity_type.__name__}: {exc}" + ) from exc + + +def next_link(data: dict[str, Any]) -> str | None: + """Extract ``@odata.nextLink`` from a collection response, or ``None``.""" + return data.get(RESPONSE_NEXT_LINK) diff --git a/src/sap_cloud_sdk/core/odata/_transport.py b/src/sap_cloud_sdk/core/odata/_transport.py new file mode 100644 index 0000000..a278f4f --- /dev/null +++ b/src/sap_cloud_sdk/core/odata/_transport.py @@ -0,0 +1,157 @@ +"""Synchronous HTTP transport for OData v4 services.""" + +from __future__ import annotations + +import logging +from typing import Any + +import requests +from requests.exceptions import RequestException + +from sap_cloud_sdk.core.odata._constants import ( + CSRF_HEADER, + DEFAULT_HEADERS, + MUTATING_METHODS, + REQUEST_TIMEOUT, +) +from sap_cloud_sdk.core.odata._csrf import CsrfTokenProvider +from sap_cloud_sdk.core.odata.exceptions import ( + ODataAuthError, + ODataConnectionError, + ODataNotFoundError, + ODataRequestError, +) + +logger = logging.getLogger(__name__) + + +class ODataHttpTransport: + """Reusable synchronous HTTP transport for OData v4 services. + + Owns the ``requests.Session``, JSON serialisation, CSRF token handling, + and status-code–to-exception mapping. Designed to be injected into + request builders. + + Args: + base_url: Root URL of the OData service + (e.g. ``https://host/sap/opu/odata4/svc/``). + session: Pre-configured ``requests.Session`` (auth headers set by + the caller, e.g. via an OAuth2 adapter or a destination factory). + csrf_enabled: Whether to fetch and attach CSRF tokens on mutating + requests. Set to ``False`` for services that do not require it. + + Example:: + + transport = ODataHttpTransport( + base_url="https://example.com/odata/v4/", + session=oauth_session, + ) + data = transport.request("GET", "BusinessPartnerSet", params={"$top": "10"}) + """ + + def __init__( + self, + base_url: str, + session: requests.Session, + csrf_enabled: bool = True, + ) -> None: + self._base_url = base_url.rstrip("/") + self._session = session + self._csrf: CsrfTokenProvider | None = ( + CsrfTokenProvider(self) if csrf_enabled else None + ) + + def request( + self, + method: str, + path: str, + *, + params: dict[str, Any] | None = None, + json: Any | None = None, + headers: dict[str, str] | None = None, + ) -> dict[str, Any]: + """Execute an OData request and return the parsed JSON body. + + CSRF tokens are fetched and attached automatically for mutating methods + (POST, PUT, PATCH, DELETE) when ``csrf_enabled`` is ``True``. On a + 403 response the cached token is invalidated and the request is retried + once with a fresh token. + + Args: + method: HTTP method (``"GET"``, ``"POST"``, ``"PATCH"``, etc.). + path: Entity path relative to the service base URL. + params: OData query parameters (``$filter``, ``$top``, …). + json: Request body serialised as JSON. + headers: Extra headers merged on top of the defaults + (``Accept: application/json``, ``Content-Type: application/json``). + + Returns: + Parsed JSON response body, or ``{}`` for 204 / empty responses. + """ + extra = dict(headers or {}) + + if method.upper() in MUTATING_METHODS and self._csrf is not None: + extra[CSRF_HEADER] = self._csrf.get() + try: + return self._execute( + method, path, params=params, json=json, extra_headers=extra + ) + except ODataAuthError as exc: + if exc.status_code == 403: + self._csrf.invalidate() + extra[CSRF_HEADER] = self._csrf.get() + return self._execute( + method, path, params=params, json=json, extra_headers=extra + ) + raise + + return self._execute( + method, path, params=params, json=json, extra_headers=extra + ) + + def absolute_url(self, path: str) -> str: + """Return the full URL for *path* relative to the service base.""" + return self._base_url + "/" + path.lstrip("/") + + # ------------------------------------------------------------------ + # Internal + # ------------------------------------------------------------------ + + def _execute( + self, + method: str, + path: str, + *, + params: dict[str, Any] | None = None, + json: Any | None = None, + extra_headers: dict[str, str] | None = None, + ) -> dict[str, Any]: + url = self.absolute_url(path) + req_headers = {**DEFAULT_HEADERS, **(extra_headers or {})} + + logger.debug("%s %s params=%s", method, url, params) + try: + resp = self._session.request( + method=method, + url=url, + headers=req_headers, + params=params, + json=json, + timeout=REQUEST_TIMEOUT, + ) + except RequestException as exc: + raise ODataConnectionError(str(exc)) from exc + + self._raise_for_status(resp) + + if resp.status_code == 204 or not resp.content: + return {} + return resp.json() + + def _raise_for_status(self, response: requests.Response) -> None: + if response.status_code == 404: + raise ODataNotFoundError(response) + if response.status_code in (401, 403): + raise ODataAuthError(response) + if not response.ok: + raise ODataRequestError(response) diff --git a/src/sap_cloud_sdk/core/odata/exceptions.py b/src/sap_cloud_sdk/core/odata/exceptions.py new file mode 100644 index 0000000..1490122 --- /dev/null +++ b/src/sap_cloud_sdk/core/odata/exceptions.py @@ -0,0 +1,45 @@ +"""OData-specific exception hierarchy.""" + +from __future__ import annotations + +from typing import Any + + +class ODataError(Exception): + """Base for all OData-related errors.""" + + +class ODataRequestError(ODataError): + """HTTP-level error from an OData service (non-2xx response).""" + + def __init__(self, response: Any) -> None: + self.status_code: int = response.status_code + self.response = response + try: + body = response.json() + err = body.get("error") or {} + detail = err.get("message") or err.get("code") + except Exception: + detail = None + suffix = f" — {detail}" if detail else "" + super().__init__(f"OData request failed: HTTP {response.status_code}{suffix}") + + +class ODataNotFoundError(ODataRequestError): + """Entity not found (HTTP 404).""" + + +class ODataAuthError(ODataRequestError): + """Authentication or authorization failure (HTTP 401/403).""" + + +class ODataDeserializationError(ODataError): + """Failed to deserialize an OData response payload.""" + + +class ODataConnectionError(ODataError): + """Network-level error reaching an OData service (no HTTP response received).""" + + +class ODataCsrfError(ODataError): + """Failed to fetch or validate a CSRF token.""" diff --git a/src/sap_cloud_sdk/core/odata/py.typed b/src/sap_cloud_sdk/core/odata/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/sap_cloud_sdk/core/odata/user-guide.md b/src/sap_cloud_sdk/core/odata/user-guide.md new file mode 100644 index 0000000..09f293f --- /dev/null +++ b/src/sap_cloud_sdk/core/odata/user-guide.md @@ -0,0 +1,282 @@ +# OData Core User Guide + +This module provides shared, reusable OData v4 building blocks for service modules and generated clients within the SAP Cloud SDK for Python. + +It is an internal package (`core/odata`). Import from it directly: + +```python +from sap_cloud_sdk.core.odata._transport import ODataHttpTransport +from sap_cloud_sdk.core.odata._request_builders import GetAllRequestBuilder +from sap_cloud_sdk.core.odata._filter import FilterExpression +from sap_cloud_sdk.core.odata._query import StructuredQuery, OrderDirection +from sap_cloud_sdk.core.odata._factory import odata_transport_from_destination +from sap_cloud_sdk.core.odata.exceptions import ODataError, ODataNotFoundError +``` + +## Destination integration + +`odata_transport_from_destination` builds an `ODataHttpTransport` from a resolved BTP Destination. The destination's auth tokens and ERP headers are pre-baked into the underlying session, so no manual header management is needed. + +```python +from sap_cloud_sdk.destination import create_client +from sap_cloud_sdk.core.odata._factory import odata_transport_from_destination +from sap_cloud_sdk.core.odata._request_builders import GetAllRequestBuilder + +dest_client = create_client() +destination = dest_client.get_destination("S4HANA_OData") + +transport = odata_transport_from_destination(destination) +results = GetAllRequestBuilder(transport, BusinessPartner).top(10).execute() +``` + +When the destination URL points to the host root rather than the OData service root, pass `odata_path`: + +```python +transport = odata_transport_from_destination( + destination, + odata_path="sap/opu/odata4/svc/API_BUSINESS_PARTNER/", +) +``` + +Set `csrf_enabled=False` for services that do not require CSRF tokens: + +```python +transport = odata_transport_from_destination(destination, csrf_enabled=False) +``` + +## Transport + +`ODataHttpTransport` wraps a `requests.Session` and handles JSON serialisation, CSRF token fetch-and-retry, and status-code–to-exception mapping. + +```python +import requests +from sap_cloud_sdk.core.odata._transport import ODataHttpTransport + +session = requests.Session() +session.headers["Authorization"] = "Bearer " + +transport = ODataHttpTransport( + base_url="https://host/sap/opu/odata4/svc/", + session=session, +) +``` + +Set `csrf_enabled=False` for services that do not require CSRF tokens: + +```python +transport = ODataHttpTransport(base_url="...", session=session, csrf_enabled=False) +``` + +### Async transport + +`AsyncODataHttpTransport` mirrors the sync interface using `httpx.AsyncClient`: + +```python +import httpx +from sap_cloud_sdk.core.odata._async_transport import AsyncODataHttpTransport + +async with AsyncODataHttpTransport( + base_url="https://host/sap/opu/odata4/svc/", + client=httpx.AsyncClient(headers={"Authorization": "Bearer "}), +) as transport: + data = await transport.request("GET", "EntitySet", params={"$top": "10"}) +``` + +## Request Builders + +Request builders compose a transport, an entity type, and optional query options into a typed, fluent API. + +### GetAllRequestBuilder + +```python +from sap_cloud_sdk.core.odata._request_builders import GetAllRequestBuilder +from sap_cloud_sdk.core.odata._filter import FilterExpression +from sap_cloud_sdk.core.odata._query import OrderDirection + +results = ( + GetAllRequestBuilder(transport, BusinessPartner) + .select("BusinessPartnerID", "DisplayName") + .filter(FilterExpression.field("DisplayName").contains("Acme")) + .order_by("DisplayName", OrderDirection.ASC) + .top(50) + .execute() +) +``` + +### GetByKeyRequestBuilder + +```python +from sap_cloud_sdk.core.odata._request_builders import GetByKeyRequestBuilder + +partner = ( + GetByKeyRequestBuilder(transport, BusinessPartner, {"BusinessPartnerID": "1000001"}) + .expand("ToAddresses") + .execute() +) +``` + +Raises `ODataNotFoundError` when the entity does not exist. + +### CreateRequestBuilder + +```python +from sap_cloud_sdk.core.odata._request_builders import CreateRequestBuilder + +new_partner = BusinessPartner(BusinessPartnerID="", DisplayName="New Corp") +created = CreateRequestBuilder(transport, new_partner).execute() +``` + +### UpdateRequestBuilder + +PATCH (partial update) by default; call `.replace()` to switch to PUT: + +```python +from sap_cloud_sdk.core.odata._request_builders import UpdateRequestBuilder + +# PATCH +updated = UpdateRequestBuilder(transport, partner).execute() + +# PUT — full replacement +updated = UpdateRequestBuilder(transport, partner).replace().execute() + +# With ETag for optimistic locking +updated = UpdateRequestBuilder(transport, partner, etag='"W/\\"1234\\""').execute() +``` + +### DeleteRequestBuilder + +```python +from sap_cloud_sdk.core.odata._request_builders import DeleteRequestBuilder + +DeleteRequestBuilder(transport, BusinessPartner, {"BusinessPartnerID": "1000001"}).execute() +``` + +## FilterExpression + +Build `$filter` expressions without string manipulation: + +```python +from sap_cloud_sdk.core.odata._filter import FilterExpression + +# Simple comparison +f = FilterExpression.field("Price").gt(100) +str(f) # "Price gt 100" + +# Combine with and_ / or_ / not_ +f = ( + FilterExpression.field("Price").gt(100) + .and_(FilterExpression.field("Category").eq("Books")) +) +str(f) # "(Price gt 100) and (Category eq 'Books')" + +# String functions +f = FilterExpression.field("Name").contains("Acme") +str(f) # "contains(Name, 'Acme')" +``` + +Available operators: `eq`, `ne`, `lt`, `le`, `gt`, `ge`, `contains`, `starts_with`, `ends_with`. + +## StructuredQuery + +Immutable query builder — each method returns a new instance, safe to share: + +```python +from sap_cloud_sdk.core.odata._query import StructuredQuery, OrderDirection + +base = StructuredQuery().select("ID", "Name").top(20) +page1 = base.skip(0) +page2 = base.skip(20) + +page1.to_params() +# {"$select": "ID,Name", "$top": "20", "$skip": "0"} +``` + +Pass the result directly to a transport call or a request builder. + +## Pagination + +Server-driven pagination via `@odata.nextLink` is built into `GetAllRequestBuilder`: + +```python +# Yield pages lazily +for page in builder.iterate_pages(): + for entity in page: + process(entity) + +# Or flatten across all pages +for entity in builder.iterate_entities(): + process(entity) +``` + +`ODataPageIterator` can also be used directly when you manage the transport call yourself: + +```python +from sap_cloud_sdk.core.odata._pagination import ODataPageIterator + +iterator = ODataPageIterator( + fetch_page=lambda url: transport.request("GET", url.removeprefix(transport._base_url + "/")), + entity_type=BusinessPartner, + first_url=transport.absolute_url("BusinessPartnerSet?$top=100"), +) +for page in iterator: + ... +``` + +## Entity Model + +Entity dataclasses declare metadata via `ClassVar` annotations so the transport layer can reflect on key fields and entity-set names: + +```python +from dataclasses import dataclass +from typing import ClassVar +from sap_cloud_sdk.core.odata._models import ODataEntity + +@dataclass +class BusinessPartner(ODataEntity): + _entity_set: ClassVar[str] = "BusinessPartnerSet" + _key_fields: ClassVar[list[str]] = ["BusinessPartnerID"] + + BusinessPartnerID: str = "" + DisplayName: str = "" +``` + +Plain dataclasses (without `ODataEntity`) also work — `_entity_set` defaults to the class name and request builders that need key fields will raise `ValueError` if `_key_fields` is absent. + +## Error Handling + +```python +from sap_cloud_sdk.core.odata.exceptions import ( + ODataError, + ODataRequestError, + ODataNotFoundError, + ODataAuthError, + ODataDeserializationError, + ODataCsrfError, +) + +try: + partner = GetByKeyRequestBuilder(transport, BusinessPartner, key).execute() +except ODataNotFoundError: + ... +except ODataAuthError as e: + print(f"Auth failure (HTTP {e.status_code})") +except ODataCsrfError as e: + print(f"CSRF handshake failed: {e}") +except ODataRequestError as e: + print(f"Service error (HTTP {e.status_code}): {e}") +except ODataDeserializationError as e: + print(f"Could not parse response: {e}") +except ODataError: + ... +``` + +Exception hierarchy: + +``` +ODataError +├── ODataRequestError # non-2xx HTTP response +│ ├── ODataNotFoundError # 404 +│ └── ODataAuthError # 401 / 403 +├── ODataDeserializationError +└── ODataCsrfError +``` diff --git a/src/sap_cloud_sdk/core/telemetry/module.py b/src/sap_cloud_sdk/core/telemetry/module.py index d1f605f..8c83a7f 100644 --- a/src/sap_cloud_sdk/core/telemetry/module.py +++ b/src/sap_cloud_sdk/core/telemetry/module.py @@ -17,6 +17,7 @@ class Module(str, Enum): DMS = "dms" EXTENSIBILITY = "extensibility" OBJECTSTORE = "objectstore" + ODATA = "odata" PRINT = "print" TELEMETRY = "telemetry" diff --git a/src/sap_cloud_sdk/core/telemetry/operation.py b/src/sap_cloud_sdk/core/telemetry/operation.py index e44ab64..0423a9f 100644 --- a/src/sap_cloud_sdk/core/telemetry/operation.py +++ b/src/sap_cloud_sdk/core/telemetry/operation.py @@ -201,5 +201,12 @@ class Operation(str, Enum): AGENT_MEMORY_GET_RETENTION_CONFIG = "get_retention_config" AGENT_MEMORY_UPDATE_RETENTION_CONFIG = "update_retention_config" + # OData Operations + ODATA_GET_ALL = "odata_get_all" + ODATA_GET_BY_KEY = "odata_get_by_key" + ODATA_CREATE = "odata_create" + ODATA_UPDATE = "odata_update" + ODATA_DELETE = "odata_delete" + def __str__(self) -> str: return self.value diff --git a/tests/core/unit/odata/__init__.py b/tests/core/unit/odata/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/core/unit/odata/test_csrf.py b/tests/core/unit/odata/test_csrf.py new file mode 100644 index 0000000..e8b7dc2 --- /dev/null +++ b/tests/core/unit/odata/test_csrf.py @@ -0,0 +1,77 @@ +"""Unit tests for CsrfTokenProvider.""" + +from unittest.mock import MagicMock + +import pytest +import requests + +from sap_cloud_sdk.core.odata._csrf import CsrfTokenProvider +from sap_cloud_sdk.core.odata._transport import ODataHttpTransport +from sap_cloud_sdk.core.odata.exceptions import ODataCsrfError + + +def _make_transport(session: requests.Session) -> ODataHttpTransport: + return ODataHttpTransport( + base_url="https://example.com/odata/v4", + session=session, + csrf_enabled=False, # we manage CsrfTokenProvider manually in these tests + ) + + +class TestCsrfTokenProvider: + def test_returns_token_from_response_header(self): + session = MagicMock(spec=requests.Session) + resp = MagicMock(spec=requests.Response) + resp.status_code = 200 + resp.headers = {"X-CSRF-Token": "my-token"} + session.get.return_value = resp + + provider = CsrfTokenProvider(_make_transport(session)) + assert provider.get() == "my-token" + + def test_caches_token(self): + session = MagicMock(spec=requests.Session) + resp = MagicMock(spec=requests.Response) + resp.status_code = 200 + resp.headers = {"X-CSRF-Token": "cached"} + session.get.return_value = resp + + provider = CsrfTokenProvider(_make_transport(session)) + provider.get() + provider.get() + assert session.get.call_count == 1 + + def test_invalidate_clears_cache(self): + session = MagicMock(spec=requests.Session) + resp1 = MagicMock(spec=requests.Response) + resp1.status_code = 200 + resp1.headers = {"X-CSRF-Token": "tok1"} + resp2 = MagicMock(spec=requests.Response) + resp2.status_code = 200 + resp2.headers = {"X-CSRF-Token": "tok2"} + session.get.side_effect = [resp1, resp2] + + provider = CsrfTokenProvider(_make_transport(session)) + assert provider.get() == "tok1" + provider.invalidate() + assert provider.get() == "tok2" + assert session.get.call_count == 2 + + def test_raises_csrf_error_when_no_token_in_response(self): + session = MagicMock(spec=requests.Session) + resp = MagicMock(spec=requests.Response) + resp.status_code = 200 + resp.headers = {} + session.get.return_value = resp + + provider = CsrfTokenProvider(_make_transport(session)) + with pytest.raises(ODataCsrfError): + provider.get() + + def test_raises_csrf_error_on_network_failure(self): + session = MagicMock(spec=requests.Session) + session.get.side_effect = requests.RequestException("timeout") + + provider = CsrfTokenProvider(_make_transport(session)) + with pytest.raises(ODataCsrfError, match="CSRF fetch failed"): + provider.get() diff --git a/tests/core/unit/odata/test_factory.py b/tests/core/unit/odata/test_factory.py new file mode 100644 index 0000000..f3d980e --- /dev/null +++ b/tests/core/unit/odata/test_factory.py @@ -0,0 +1,74 @@ +"""Unit tests for odata_transport_from_destination factory.""" + +from unittest.mock import MagicMock + +import pytest + +from sap_cloud_sdk.core.odata._factory import odata_transport_from_destination +from sap_cloud_sdk.core.odata._transport import ODataHttpTransport +from sap_cloud_sdk.destination._models import Destination, DestinationType + + +def _make_destination( + url: str = "https://s4hana.example.com", + type_: DestinationType = DestinationType.HTTP, + headers: dict | None = None, +) -> Destination: + dest = MagicMock(spec=Destination) + dest.url = url + dest.type = type_ + dest.name = "test-destination" + dest.get_headers.return_value = headers or {"Authorization": "Bearer tok"} + return dest + + +class TestOdataTransportFromDestination: + def test_returns_odata_http_transport(self): + transport = odata_transport_from_destination(_make_destination()) + assert isinstance(transport, ODataHttpTransport) + + def test_base_url_is_destination_url(self): + transport = odata_transport_from_destination( + _make_destination(url="https://host.example.com") + ) + assert transport._base_url == "https://host.example.com" + + def test_odata_path_appended_to_base_url(self): + transport = odata_transport_from_destination( + _make_destination(url="https://host.example.com"), + odata_path="sap/opu/odata4/svc", + ) + assert transport._base_url == "https://host.example.com/sap/opu/odata4/svc" + + def test_trailing_slash_stripped_from_destination_url(self): + transport = odata_transport_from_destination( + _make_destination(url="https://host.example.com/") + ) + assert transport._base_url == "https://host.example.com" + + def test_destination_headers_baked_into_session(self): + dest = _make_destination(headers={"Authorization": "Bearer abc", "sap-client": "100"}) + transport = odata_transport_from_destination(dest) + assert transport._session.headers["Authorization"] == "Bearer abc" + assert transport._session.headers["sap-client"] == "100" + + def test_csrf_enabled_by_default(self): + transport = odata_transport_from_destination(_make_destination()) + assert transport._csrf is not None + + def test_csrf_disabled_when_requested(self): + transport = odata_transport_from_destination( + _make_destination(), csrf_enabled=False + ) + assert transport._csrf is None + + def test_raises_for_non_http_destination(self): + dest = _make_destination(type_=DestinationType.RFC) + with pytest.raises(ValueError, match="HTTP destinations"): + odata_transport_from_destination(dest) + + def test_raises_when_destination_has_no_url(self): + dest = _make_destination(url="") + dest.url = None + with pytest.raises(ValueError, match="no URL"): + odata_transport_from_destination(dest) diff --git a/tests/core/unit/odata/test_filter_expression.py b/tests/core/unit/odata/test_filter_expression.py new file mode 100644 index 0000000..15d7bab --- /dev/null +++ b/tests/core/unit/odata/test_filter_expression.py @@ -0,0 +1,100 @@ +"""Unit tests for FilterExpression DSL.""" + +import pytest +from sap_cloud_sdk.core.odata._filter import FilterExpression, _format_value + + +class TestFormatValue: + def test_string_single_quotes(self): + assert _format_value("hello") == "'hello'" + + def test_string_escapes_embedded_quotes(self): + assert _format_value("O'Brien") == "'O''Brien'" + + def test_integer(self): + assert _format_value(42) == "42" + + def test_float(self): + assert _format_value(3.14) == "3.14" + + def test_bool_true(self): + assert _format_value(True) == "true" + + def test_bool_false(self): + assert _format_value(False) == "false" + + +class TestFieldRef: + def test_eq_string(self): + f = FilterExpression.field("Name").eq("Acme") + assert str(f) == "Name eq 'Acme'" + + def test_ne(self): + assert str(FilterExpression.field("Status").ne("X")) == "Status ne 'X'" + + def test_lt(self): + assert str(FilterExpression.field("Price").lt(100)) == "Price lt 100" + + def test_le(self): + assert str(FilterExpression.field("Price").le(100)) == "Price le 100" + + def test_gt(self): + assert str(FilterExpression.field("Price").gt(0)) == "Price gt 0" + + def test_ge(self): + assert str(FilterExpression.field("Price").ge(1)) == "Price ge 1" + + def test_contains(self): + assert ( + str(FilterExpression.field("Name").contains("Acme")) + == "contains(Name, 'Acme')" + ) + + def test_starts_with(self): + assert ( + str(FilterExpression.field("Name").starts_with("A")) + == "startswith(Name, 'A')" + ) + + def test_ends_with(self): + assert ( + str(FilterExpression.field("Name").ends_with("Corp")) + == "endswith(Name, 'Corp')" + ) + + +class TestFilterExpressionComposition: + def test_and_(self): + f = ( + FilterExpression.field("Price").gt(100) + .and_(FilterExpression.field("Category").eq("Books")) + ) + assert str(f) == "(Price gt 100) and (Category eq 'Books')" + + def test_or_(self): + f = ( + FilterExpression.field("Status").eq("A") + .or_(FilterExpression.field("Status").eq("B")) + ) + assert str(f) == "(Status eq 'A') or (Status eq 'B')" + + def test_not_(self): + f = FilterExpression.field("Deleted").eq(True).not_() + assert str(f) == "not (Deleted eq true)" + + def test_chained_and_or(self): + f = ( + FilterExpression.field("A").eq(1) + .and_(FilterExpression.field("B").eq(2)) + .or_(FilterExpression.field("C").eq(3)) + ) + assert str(f) == "((A eq 1) and (B eq 2)) or (C eq 3)" + + def test_equality(self): + a = FilterExpression.field("X").eq(1) + b = FilterExpression.field("X").eq(1) + assert a == b + + def test_hash_consistency(self): + a = FilterExpression.field("X").eq(1) + assert hash(a) == hash(FilterExpression.field("X").eq(1)) diff --git a/tests/core/unit/odata/test_pagination.py b/tests/core/unit/odata/test_pagination.py new file mode 100644 index 0000000..4acc30f --- /dev/null +++ b/tests/core/unit/odata/test_pagination.py @@ -0,0 +1,79 @@ +"""Unit tests for ODataPageIterator.""" + +from typing import Any + +import pytest + +from sap_cloud_sdk.core.odata._pagination import ODataPageIterator +from sap_cloud_sdk.core.odata.exceptions import ODataDeserializationError + + +from dataclasses import dataclass + + +@dataclass +class _Item: + id: str = "" + name: str = "" + + +class TestODataPageIterator: + def test_single_page_no_next_link(self): + pages = [{"value": [{"id": "1", "name": "A"}, {"id": "2", "name": "B"}]}] + urls: list[str] = [] + + def fetch(url: str) -> dict[str, Any]: + urls.append(url) + return pages.pop(0) + + iterator = ODataPageIterator(fetch, _Item, "https://host/svc/Items") + result = list(iterator) + + assert len(result) == 1 + assert result[0] == [_Item(id="1", name="A"), _Item(id="2", name="B")] + assert urls == ["https://host/svc/Items"] + + def test_multi_page_follows_next_link(self): + responses = [ + { + "value": [{"id": "1", "name": "A"}], + "@odata.nextLink": "https://host/svc/Items?$skip=1", + }, + {"value": [{"id": "2", "name": "B"}]}, + ] + + def fetch(url: str) -> dict[str, Any]: + return responses.pop(0) + + iterator = ODataPageIterator(fetch, _Item, "https://host/svc/Items") + pages = list(iterator) + assert len(pages) == 2 + assert pages[0] == [_Item(id="1", name="A")] + assert pages[1] == [_Item(id="2", name="B")] + + def test_entities_yields_individual_items(self): + responses = [ + { + "value": [{"id": "1", "name": "A"}, {"id": "2", "name": "B"}], + "@odata.nextLink": "https://host/next", + }, + {"value": [{"id": "3", "name": "C"}]}, + ] + + def fetch(url: str) -> dict[str, Any]: + return responses.pop(0) + + iterator = ODataPageIterator(fetch, _Item, "https://host/svc/Items") + entities = list(iterator.entities()) + assert entities == [ + _Item(id="1", name="A"), + _Item(id="2", name="B"), + _Item(id="3", name="C"), + ] + + def test_empty_collection(self): + def fetch(url: str) -> dict[str, Any]: + return {"value": []} + + iterator = ODataPageIterator(fetch, _Item, "https://host/svc/Items") + assert list(iterator) == [[]] diff --git a/tests/core/unit/odata/test_request_builders.py b/tests/core/unit/odata/test_request_builders.py new file mode 100644 index 0000000..0092605 --- /dev/null +++ b/tests/core/unit/odata/test_request_builders.py @@ -0,0 +1,221 @@ +"""Unit tests for CRUD request builders.""" + +from dataclasses import dataclass +from typing import Any, ClassVar +from unittest.mock import MagicMock + +import pytest +import requests + +from sap_cloud_sdk.core.odata._filter import FilterExpression +from sap_cloud_sdk.core.odata._models import ODataEntity +from sap_cloud_sdk.core.odata._request_builders import ( + CreateRequestBuilder, + DeleteRequestBuilder, + GetAllRequestBuilder, + GetByKeyRequestBuilder, + UpdateRequestBuilder, + _build_key_segment, + _entity_set_path, +) +from sap_cloud_sdk.core.odata._transport import ODataHttpTransport + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +@dataclass +class _Partner(ODataEntity): + _entity_set: ClassVar[str] = "BusinessPartnerSet" + _key_fields: ClassVar[list[str]] = ["PartnerID"] + + PartnerID: str = "" + Name: str = "" + + +def _make_transport(session: requests.Session) -> ODataHttpTransport: + return ODataHttpTransport( + base_url="https://host/svc", + session=session, + csrf_enabled=False, + ) + + +def _mock_response(status_code: int = 200, json_data: Any = None) -> MagicMock: + resp = MagicMock(spec=requests.Response) + resp.status_code = status_code + resp.ok = 200 <= status_code < 300 + resp.content = b"data" + resp.json.return_value = json_data if json_data is not None else {} + resp.headers = {} + return resp + + +# --------------------------------------------------------------------------- +# Helpers tests +# --------------------------------------------------------------------------- + + +class TestBuildKeySegment: + def test_single_string_key(self): + assert _build_key_segment({"ID": "x"}) == "('x')" + + def test_single_int_key(self): + assert _build_key_segment({"ID": 1}) == "(1)" + + def test_composite_key(self): + seg = _build_key_segment({"ID": "x", "Ver": 1}) + assert seg == "(ID='x',Ver=1)" + + +class TestEntitySetPath: + def test_reads_entity_set_classvar(self): + assert _entity_set_path(_Partner) == "BusinessPartnerSet" + + def test_falls_back_to_class_name(self): + @dataclass + class NoMeta: + ID: str = "" + + assert _entity_set_path(NoMeta) == "NoMeta" + + +# --------------------------------------------------------------------------- +# GetAllRequestBuilder +# --------------------------------------------------------------------------- + + +class TestGetAllRequestBuilder: + def test_execute_calls_correct_path(self): + session = MagicMock(spec=requests.Session) + session.request.return_value = _mock_response(200, {"value": [{"PartnerID": "1", "Name": "A"}]}) + transport = _make_transport(session) + + results = GetAllRequestBuilder(transport, _Partner).execute() + + url = session.request.call_args[1]["url"] + assert "BusinessPartnerSet" in url + assert results == [_Partner(PartnerID="1", Name="A")] + + def test_execute_passes_query_params(self): + session = MagicMock(spec=requests.Session) + session.request.return_value = _mock_response(200, {"value": []}) + transport = _make_transport(session) + + ( + GetAllRequestBuilder(transport, _Partner) + .select("PartnerID") + .top(5) + .filter(FilterExpression.field("Name").eq("Acme")) + .execute() + ) + + params = session.request.call_args[1]["params"] + assert params["$select"] == "PartnerID" + assert params["$top"] == "5" + assert params["$filter"] == "Name eq 'Acme'" + + +# --------------------------------------------------------------------------- +# GetByKeyRequestBuilder +# --------------------------------------------------------------------------- + + +class TestGetByKeyRequestBuilder: + def test_execute_builds_key_in_path(self): + session = MagicMock(spec=requests.Session) + session.request.return_value = _mock_response(200, {"PartnerID": "42", "Name": "X"}) + transport = _make_transport(session) + + result = GetByKeyRequestBuilder( + transport, _Partner, {"PartnerID": "42"} + ).execute() + + url = session.request.call_args[1]["url"] + assert "BusinessPartnerSet" in url + assert "'42'" in url + assert result == _Partner(PartnerID="42", Name="X") + + +# --------------------------------------------------------------------------- +# CreateRequestBuilder +# --------------------------------------------------------------------------- + + +class TestCreateRequestBuilder: + def test_execute_posts_entity_and_returns_result(self): + session = MagicMock(spec=requests.Session) + session.request.return_value = _mock_response(201, {"PartnerID": "new", "Name": "New"}) + transport = _make_transport(session) + + entity = _Partner(PartnerID="new", Name="New") + result = CreateRequestBuilder(transport, entity).execute() + + assert session.request.call_args[1]["method"] == "POST" + assert result == _Partner(PartnerID="new", Name="New") + + +# --------------------------------------------------------------------------- +# UpdateRequestBuilder +# --------------------------------------------------------------------------- + + +class TestUpdateRequestBuilder: + def test_execute_patches_by_default(self): + session = MagicMock(spec=requests.Session) + resp = _mock_response(204) + resp.content = b"" + session.request.return_value = resp + transport = _make_transport(session) + + entity = _Partner(PartnerID="1", Name="Updated") + UpdateRequestBuilder(transport, entity).execute() + + assert session.request.call_args[1]["method"] == "PATCH" + + def test_replace_uses_put(self): + session = MagicMock(spec=requests.Session) + resp = _mock_response(204) + resp.content = b"" + session.request.return_value = resp + transport = _make_transport(session) + + entity = _Partner(PartnerID="1", Name="Updated") + UpdateRequestBuilder(transport, entity).replace().execute() + + assert session.request.call_args[1]["method"] == "PUT" + + def test_etag_sent_in_if_match_header(self): + session = MagicMock(spec=requests.Session) + resp = _mock_response(204) + resp.content = b"" + session.request.return_value = resp + transport = _make_transport(session) + + entity = _Partner(PartnerID="1", Name="X") + UpdateRequestBuilder(transport, entity, etag='"W/123"').execute() + + headers = session.request.call_args[1]["headers"] + assert headers["If-Match"] == '"W/123"' + + +# --------------------------------------------------------------------------- +# DeleteRequestBuilder +# --------------------------------------------------------------------------- + + +class TestDeleteRequestBuilder: + def test_execute_sends_delete(self): + session = MagicMock(spec=requests.Session) + resp = _mock_response(204) + resp.content = b"" + session.request.return_value = resp + transport = _make_transport(session) + + DeleteRequestBuilder(transport, _Partner, {"PartnerID": "1"}).execute() + + assert session.request.call_args[1]["method"] == "DELETE" + url = session.request.call_args[1]["url"] + assert "BusinessPartnerSet" in url diff --git a/tests/core/unit/odata/test_structured_query.py b/tests/core/unit/odata/test_structured_query.py new file mode 100644 index 0000000..b62d034 --- /dev/null +++ b/tests/core/unit/odata/test_structured_query.py @@ -0,0 +1,96 @@ +"""Unit tests for StructuredQuery builder.""" + +import pytest +from sap_cloud_sdk.core.odata._query import OrderDirection, StructuredQuery +from sap_cloud_sdk.core.odata._filter import FilterExpression + + +class TestStructuredQueryImmutability: + def test_select_returns_new_instance(self): + q = StructuredQuery() + q2 = q.select("A", "B") + assert q is not q2 + assert q.to_params() == {} + assert "$select" in q2.to_params() + + def test_chained_calls_do_not_mutate_base(self): + base = StructuredQuery().top(10) + page1 = base.skip(0) + page2 = base.skip(10) + assert page1.to_params()["$skip"] == "0" + assert page2.to_params()["$skip"] == "10" + assert "$skip" not in base.to_params() + + +class TestToParams: + def test_empty_query_produces_no_params(self): + assert StructuredQuery().to_params() == {} + + def test_select(self): + params = StructuredQuery().select("ID", "Name").to_params() + assert params["$select"] == "ID,Name" + + def test_top(self): + assert StructuredQuery().top(20).to_params()["$top"] == "20" + + def test_skip(self): + assert StructuredQuery().skip(5).to_params()["$skip"] == "5" + + def test_filter(self): + f = FilterExpression.field("Name").eq("Acme") + params = StructuredQuery().filter(f).to_params() + assert params["$filter"] == "Name eq 'Acme'" + + def test_order_by_asc(self): + params = StructuredQuery().order_by("Name").to_params() + assert params["$orderby"] == "Name asc" + + def test_order_by_desc(self): + params = ( + StructuredQuery() + .order_by("CreatedAt", OrderDirection.DESC) + .to_params() + ) + assert params["$orderby"] == "CreatedAt desc" + + def test_multiple_order_by_fields(self): + params = ( + StructuredQuery() + .order_by("Name", OrderDirection.ASC) + .order_by("CreatedAt", OrderDirection.DESC) + .to_params() + ) + assert params["$orderby"] == "Name asc,CreatedAt desc" + + def test_expand(self): + params = StructuredQuery().expand("ToAddresses", "ToOrders").to_params() + assert params["$expand"] == "ToAddresses,ToOrders" + + def test_custom_param(self): + params = StructuredQuery().custom("sap-language", "en").to_params() + assert params["sap-language"] == "en" + + def test_custom_param_overwrite(self): + q = StructuredQuery().custom("foo", "bar").custom("foo", "baz") + assert q.to_params()["foo"] == "baz" + + def test_full_query(self): + f = FilterExpression.field("Name").eq("Acme") + params = ( + StructuredQuery() + .select("ID", "Name") + .filter(f) + .order_by("Name") + .top(20) + .skip(0) + .expand("ToAddresses") + .to_params() + ) + assert params == { + "$select": "ID,Name", + "$filter": "Name eq 'Acme'", + "$orderby": "Name asc", + "$top": "20", + "$skip": "0", + "$expand": "ToAddresses", + } diff --git a/tests/core/unit/odata/test_transport.py b/tests/core/unit/odata/test_transport.py new file mode 100644 index 0000000..df13bae --- /dev/null +++ b/tests/core/unit/odata/test_transport.py @@ -0,0 +1,151 @@ +"""Unit tests for ODataHttpTransport.""" + +from typing import Any +from unittest.mock import MagicMock + +import pytest +import requests + +from sap_cloud_sdk.core.odata._transport import ODataHttpTransport +from sap_cloud_sdk.core.odata.exceptions import ( + ODataAuthError, + ODataNotFoundError, + ODataRequestError, +) + + +def _mock_response(status_code: int = 200, json_data: Any = None, headers: dict | None = None) -> MagicMock: + resp = MagicMock(spec=requests.Response) + resp.status_code = status_code + resp.ok = 200 <= status_code < 300 + resp.content = b'{"value": []}' if json_data is None else b"data" + resp.json.return_value = json_data if json_data is not None else {} + resp.headers = headers or {} + return resp + + +@pytest.fixture +def session(): + return MagicMock(spec=requests.Session) + + +@pytest.fixture +def transport(session): + return ODataHttpTransport( + base_url="https://example.com/odata/v4/", + session=session, + csrf_enabled=False, + ) + + +class TestRequest: + def test_builds_correct_url(self, transport, session): + session.request.return_value = _mock_response(200, {}) + transport.request("GET", "EntitySet") + assert session.request.call_args[1]["url"] == "https://example.com/odata/v4/EntitySet" + + def test_passes_params(self, transport, session): + session.request.return_value = _mock_response(200, {}) + transport.request("GET", "EntitySet", params={"$top": "5"}) + assert session.request.call_args[1]["params"] == {"$top": "5"} + + def test_sets_accept_header(self, transport, session): + session.request.return_value = _mock_response(200, {}) + transport.request("GET", "EntitySet") + assert session.request.call_args[1]["headers"]["Accept"] == "application/json" + + def test_returns_parsed_json(self, transport, session): + session.request.return_value = _mock_response(200, {"value": [{"ID": "1"}]}) + assert transport.request("GET", "EntitySet") == {"value": [{"ID": "1"}]} + + def test_404_raises_not_found(self, transport, session): + session.request.return_value = _mock_response(404) + with pytest.raises(ODataNotFoundError): + transport.request("GET", "EntitySet") + + def test_401_raises_auth_error(self, transport, session): + session.request.return_value = _mock_response(401) + with pytest.raises(ODataAuthError): + transport.request("GET", "EntitySet") + + def test_500_raises_request_error(self, transport, session): + session.request.return_value = _mock_response(500) + with pytest.raises(ODataRequestError): + transport.request("GET", "EntitySet") + + def test_204_returns_empty_dict(self, transport, session): + resp = _mock_response(204) + resp.content = b"" + session.request.return_value = resp + assert transport.request("GET", "EntitySet") == {} + + def test_extra_headers_merged(self, transport, session): + session.request.return_value = _mock_response(200, {}) + transport.request("GET", "EntitySet", headers={"sap-language": "en"}) + assert session.request.call_args[1]["headers"]["sap-language"] == "en" + + def test_passes_method_verbatim(self, transport, session): + session.request.return_value = _mock_response(201, {"ID": "1"}) + transport.request("POST", "EntitySet", json={"Name": "X"}) + assert session.request.call_args[1]["method"] == "POST" + + +class TestCsrf: + def test_csrf_attached_on_post(self, session): + csrf_resp = MagicMock(spec=requests.Response) + csrf_resp.status_code = 200 + csrf_resp.headers = {"X-CSRF-Token": "csrf-tok"} + session.get.return_value = csrf_resp + session.request.return_value = _mock_response(201, {"ID": "1"}) + + transport = ODataHttpTransport( + base_url="https://example.com/odata/v4/", + session=session, + csrf_enabled=True, + ) + transport.request("POST", "EntitySet", json={"Name": "X"}) + + assert session.request.call_args[1]["headers"]["X-CSRF-Token"] == "csrf-tok" + + def test_no_csrf_on_get(self, session): + session.request.return_value = _mock_response(200, {}) + + transport = ODataHttpTransport( + base_url="https://example.com/odata/v4/", + session=session, + csrf_enabled=True, + ) + transport.request("GET", "EntitySet") + + assert "X-CSRF-Token" not in session.request.call_args[1]["headers"] + session.get.assert_not_called() + + def test_csrf_retry_on_403(self, session): + csrf_resp1 = MagicMock(spec=requests.Response) + csrf_resp1.status_code = 200 + csrf_resp1.headers = {"X-CSRF-Token": "tok1"} + csrf_resp2 = MagicMock(spec=requests.Response) + csrf_resp2.status_code = 200 + csrf_resp2.headers = {"X-CSRF-Token": "tok2"} + + session.get.side_effect = [csrf_resp1, csrf_resp2] + session.request.side_effect = [_mock_response(403), _mock_response(201, {"ID": "1"})] + + transport = ODataHttpTransport( + base_url="https://example.com/odata/v4/", + session=session, + csrf_enabled=True, + ) + transport.request("POST", "EntitySet", json={"Name": "X"}) + + assert session.request.call_count == 2 + + +class TestAbsoluteUrl: + def test_builds_url_with_trailing_slash(self): + t = ODataHttpTransport("https://host/svc/", MagicMock(), csrf_enabled=False) + assert t.absolute_url("EntitySet") == "https://host/svc/EntitySet" + + def test_strips_leading_slash_from_path(self): + t = ODataHttpTransport("https://host/svc", MagicMock(), csrf_enabled=False) + assert t.absolute_url("/EntitySet") == "https://host/svc/EntitySet" diff --git a/tests/core/unit/telemetry/test_module.py b/tests/core/unit/telemetry/test_module.py index ccc5d1d..6a595d9 100644 --- a/tests/core/unit/telemetry/test_module.py +++ b/tests/core/unit/telemetry/test_module.py @@ -55,7 +55,7 @@ def test_module_in_collection(self): def test_all_modules_present(self): """Test that all expected modules are present.""" all_modules = list(Module) - assert len(all_modules) == 13 + assert len(all_modules) == 14 assert Module.ADMS in all_modules assert Module.AGENT_MEMORY in all_modules assert Module.AGENTGATEWAY in all_modules @@ -67,6 +67,7 @@ def test_all_modules_present(self): assert Module.DMS in all_modules assert Module.EXTENSIBILITY in all_modules assert Module.OBJECTSTORE in all_modules + assert Module.ODATA in all_modules assert Module.PRINT in all_modules assert Module.TELEMETRY in all_modules diff --git a/tests/core/unit/telemetry/test_operation.py b/tests/core/unit/telemetry/test_operation.py index ee6db49..3a28ebf 100644 --- a/tests/core/unit/telemetry/test_operation.py +++ b/tests/core/unit/telemetry/test_operation.py @@ -212,5 +212,5 @@ def test_operation_count(self): all_operations = list(Operation) # 3 auditlog + 11 destination + 10 certificate + 10 fragment + 8 objectstore # + 2 extensibility + 2 aicore + 23 dms + 4 agentgateway + 13 agent_memory - # + 5 data_anonymization + 52 adms + 6 print = 149 - assert len(all_operations) == 149 + # + 5 data_anonymization + 52 adms + 6 print + 5 odata = 154 + assert len(all_operations) == 154 diff --git a/uv.lock b/uv.lock index 3d41081..16a5300 100644 --- a/uv.lock +++ b/uv.lock @@ -3696,7 +3696,7 @@ wheels = [ [[package]] name = "sap-cloud-sdk" -version = "0.27.0" +version = "0.28.0" source = { editable = "." } dependencies = [ { name = "grpcio" },