Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

grid_integrator: Allow passing custom arguments to integrand function. #188

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions torchquad/integration/grid_integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def f(integration_domain, N, requires_grad=False, backend=None):
def _weights(self, N, dim, backend, requires_grad=False):
return None

def integrate(self, fn, dim, N, integration_domain, backend):
def integrate(self, fn, dim, N, integration_domain, backend, args=None):
"""Integrate the passed function on the passed domain using a Composite Newton Cotes rule.
The argument meanings are explained in the sub-classes.

Expand All @@ -47,7 +47,7 @@ def integrate(self, fn, dim, N, integration_domain, backend):

logger.debug("Evaluating integrand on the grid.")
function_values, num_points = self.evaluate_integrand(
fn, grid_points, weights=self._weights(n_per_dim, dim, backend)
fn, grid_points, weights=self._weights(n_per_dim, dim, backend), args=args
)
self._nr_of_fevals = num_points

Expand Down Expand Up @@ -139,7 +139,7 @@ def _adjust_N(dim, N):
return N

def get_jit_compiled_integrate(
self, dim, N=None, integration_domain=None, backend=None
self, dim, N=None, integration_domain=None, backend=None, args=None
):
"""Create an integrate function where the performance-relevant steps except the integrand evaluation are JIT compiled.
Use this method only if the integrand cannot be compiled.
Expand All @@ -151,6 +151,7 @@ def get_jit_compiled_integrate(
N (int, optional): Total number of sample points to use for the integration. See the integrate method documentation for more details.
integration_domain (list or backend tensor, optional): Integration domain, e.g. [[-1,1],[0,1]]. Defaults to [-1,1]^dim. It can also determine the numerical backend.
backend (string, optional): Numerical backend. Defaults to integration_domain's backend if it is a tensor and otherwise to the backend from the latest call to set_up_backend or "torch" for backwards compatibility.
args (list or tuple, optional): Any arguments required by the function. Defaults to None.

Returns:
function(fn, integration_domain): JIT compiled integrate function where all parameters except the integrand and domain are fixed
Expand Down Expand Up @@ -197,7 +198,7 @@ def get_jit_compiled_integrate(
def compiled_integrate(fn, integration_domain):
grid_points, hs, n_per_dim = jit_calculate_grid(N, integration_domain)
function_values, _ = self.evaluate_integrand(
fn, grid_points, weights=self._weights(n_per_dim, dim, backend)
fn, grid_points, weights=self._weights(n_per_dim, dim, backend), args=args
)
return jit_calculate_result(
function_values, dim, int(n_per_dim), hs, integration_domain
Expand Down Expand Up @@ -238,6 +239,7 @@ def step3(function_values, hs, integration_domain):
example_integrand,
grid_points,
weights=self._weights(n_per_dim, dim, backend),
args=args,
)

# Trace the third step
Expand All @@ -257,7 +259,7 @@ def step3(function_values, hs, integration_domain):
def compiled_integrate(fn, integration_domain):
grid_points, hs, _ = step1(integration_domain)
function_values, _ = self.evaluate_integrand(
fn, grid_points, weights=self._weights(n_per_dim, dim, backend)
fn, grid_points, weights=self._weights(n_per_dim, dim, backend), args=args
)
result = step3(function_values, hs, integration_domain)
return result
Expand Down
Loading