Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions .github/workflows/nox.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,16 @@ jobs:
run: uv run --group nox nox -e "pytest_min_deps-${{ matrix.python-version }}"
- name: Test with nox with all dependencies
run: uv run --group nox nox -e "pytest_all_deps-${{ matrix.python-version }}"

test-rust-backend:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.13"
- name: Install uv
uses: astral-sh/setup-uv@v6
- name: Test with nox with the Rust triangulation backend
run: uv run --group nox nox -e pytest_rust_backend
46 changes: 30 additions & 16 deletions adaptive/learner/learnerND.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
circumsphere,
point_in_simplex,
resolve_triangulation_class,
rust_default_loss,
simplex_volume_in_embedding,
)
from adaptive.notebook_integration import ensure_holoviews, ensure_plotly
Expand Down Expand Up @@ -333,7 +334,9 @@ def __init__(
):
self._triangulation_class = resolve_triangulation_class(triangulation_backend)
self._vdim = None
self.loss_per_simplex = loss_per_simplex or default_loss
# Prefer the Rust implementation of the default loss when the Rust
# backend is active; it computes the same embedded simplex volume.
self.loss_per_simplex = loss_per_simplex or rust_default_loss or default_loss

if hasattr(self.loss_per_simplex, "nth_neighbors"):
if self.loss_per_simplex.nth_neighbors > 1:
Expand Down Expand Up @@ -649,35 +652,46 @@ def tell_pending(self, point, *, simplex=None):
if self.tri is None:
return

for simpl in self._simplices_containing_point(point, simplex):
_, to_add = self._add_pending_point_to_simplex(point, simpl)
if to_add is None:
continue
self._update_subsimplex_losses(simpl, to_add)

def _simplices_containing_point(self, point, simplex=None):
"""All simplices of the triangulation containing `point`, found from
the `simplex` hint when given."""
if hasattr(self.tri, "simplices_containing"):
# Rust backend: one call instead of a point_in_simplex loop.
return self.tri.simplices_containing(point, simplex=simplex)

simplex = tuple(simplex or self.tri.locate_point(point))
if not simplex:
return
# Simplex is None if pending point is outside the triangulation,
# then you do not have subtriangles
return []
# Simplex is empty if the pending point is outside the
# triangulation, then you do not have subtriangles

simplex = tuple(simplex)
simplices = [self.tri.vertex_to_simplices[i] for i in simplex]
neighbors = set.union(*simplices)
# Neighbours also includes the simplex itself
return [s for s in neighbors if self.tri.point_in_simplex(point, s)]

for simpl in neighbors:
_, to_add = self._try_adding_pending_point_to_simplex(point, simpl)
if to_add is None:
continue
self._update_subsimplex_losses(simpl, to_add)

def _try_adding_pending_point_to_simplex(self, point, simplex):
# try to insert it
if not self.tri.point_in_simplex(point, simplex):
return None, None

def _add_pending_point_to_simplex(self, point, simplex):
"""Insert `point` into the subtriangulation of `simplex`, which must
contain the point."""
if simplex not in self._subtriangulations:
vertices = self.tri.get_vertices(simplex)
self._subtriangulations[simplex] = self._triangulation_class(vertices)

self._pending_to_simplex[point] = simplex
return self._subtriangulations[simplex].add_point(point)

def _try_adding_pending_point_to_simplex(self, point, simplex):
# try to insert it
if not self.tri.point_in_simplex(point, simplex):
return None, None
return self._add_pending_point_to_simplex(point, simplex)

def _update_subsimplex_losses(self, simplex, new_subsimplices):
loss = self._losses[simplex]

Expand Down
17 changes: 14 additions & 3 deletions adaptive/learner/triangulation_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@

import os

# Minimal version that is a complete drop-in for the learners
# (incl. ``get_opposing_vertices`` and pickle/deepcopy support).
_MIN_RUST_VERSION = (0, 2, 1)
# Minimal version that is a complete drop-in for the learners: includes the
# degenerate-simplex fix for curvature losses, plus the batched
# ``simplices_containing`` query and Rust ``default_loss`` that `LearnerND`
# uses when this backend is active.
_MIN_RUST_VERSION = (0, 3, 1)


def _rust_version() -> tuple[int, ...] | None:
Expand Down Expand Up @@ -119,6 +121,12 @@ def resolve_triangulation_class(backend="auto"):
point_in_simplex,
simplex_volume_in_embedding,
)

# The Rust implementation of `adaptive.learner.learnerND.default_loss`,
# which `LearnerND` prefers when no loss is given. Defined here (rather
# than re-exporting the Python one) to avoid a circular import with
# `learnerND`; ``None`` means "use the pure-Python default".
from adaptive_triangulation import default_loss as rust_default_loss
else:
from adaptive.learner.triangulation import (
Triangulation,
Expand All @@ -132,6 +140,8 @@ def resolve_triangulation_class(backend="auto"):
simplex_volume_in_embedding,
)

rust_default_loss = None

__all__ = [
"TRIANGULATION_BACKEND",
"Triangulation",
Expand All @@ -143,5 +153,6 @@ def resolve_triangulation_class(backend="auto"):
"fast_norm",
"orientation",
"point_in_simplex",
"rust_default_loss",
"simplex_volume_in_embedding",
]
37 changes: 37 additions & 0 deletions adaptive/tests/unit/test_triangulation_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,40 @@ def test_learnernd_uses_rust_backend():
learner.tell(point, learner.function(point))
assert isinstance(learner.tri, adaptive_triangulation.Triangulation)
assert learner.npoints >= 50


def test_rust_default_loss_matches_backend():
if backend.TRIANGULATION_BACKEND == "rust":
import adaptive_triangulation

assert backend.rust_default_loss is adaptive_triangulation.default_loss
else:
assert backend.rust_default_loss is None


def _ring_of_fire(xy):
import numpy as np

x, y = xy
a, d = 0.2, 0.5
return x + np.exp(-((x**2 + y**2 - d**2) ** 2) / a**4)


@pytest.mark.skipif(not rust_is_usable(), reason="needs adaptive-triangulation")
def test_rust_backend_samples_identical_points():
# The batched tell_pending path and the Rust default loss must not change
# which points the learner chooses.
from adaptive import LearnerND

learners = {
which: LearnerND(
_ring_of_fire, bounds=[(-1, 1), (-1, 1)], triangulation_backend=which
)
for which in ("python", "rust")
}
for learner in learners.values():
for _ in range(200):
points, _ = learner.ask(1)
for point in points:
learner.tell(point, learner.function(point))
assert sorted(learners["python"].data) == sorted(learners["rust"].data)
8 changes: 8 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ def pytest_all_deps(session: nox.Session) -> None:
session.run("pytest", *xdist)


@nox.session(python="3.13")
def pytest_rust_backend(session: nox.Session) -> None:
"""Run the test suite with the Rust triangulation backend required."""
session.install(".[test,other,rust]")
session.run("coverage", "erase")
session.run("pytest", *xdist, env={"ADAPTIVE_TRIANGULATION_BACKEND": "rust"})


@nox.session(python="3.13")
def pytest_typeguard(session: nox.Session) -> None:
"""Run pytest with typeguard."""
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ dependencies = [

[project.optional-dependencies]
rust = [
"adaptive-triangulation>=0.2.1", # Rust-accelerated triangulation backend
"adaptive-triangulation>=0.3.1", # Rust-accelerated triangulation backend
]
other = [
"dill",
Expand Down
Loading