Skip to content

Commit 295d2dd

Browse files
JasonKChowfacebook-github-bot
authored andcommitted
add support for pairwise probit gpu (facebookresearch#625)
Summary: Pull Request resolved: facebookresearch#625 Allow pairwise probit model to be used with gpu Reviewed By: crasanders Differential Revision: D69160746 fbshipit-source-id: 0ecb47776dfc3eb3b8a202d4ca5ec2bd51fbae88
1 parent a2b071e commit 295d2dd

File tree

2 files changed

+105
-15
lines changed

2 files changed

+105
-15
lines changed

Diff for: aepsych/models/pairwise_probit.py

+6-15
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import torch
1212
from aepsych.config import Config
1313
from aepsych.factory import default_mean_covar_factory
14-
from aepsych.models.base import AEPsychMixin
14+
from aepsych.models.base import AEPsychModelDeviceMixin
1515
from aepsych.utils import _process_bounds, get_dims, get_optimizer_options, promote_0d
1616
from aepsych.utils_logging import getLogger
1717
from botorch.fit import fit_gpytorch_mll
@@ -22,7 +22,7 @@
2222
logger = getLogger()
2323

2424

25-
class PairwiseProbitModel(PairwiseGP, AEPsychMixin):
25+
class PairwiseProbitModel(PairwiseGP, AEPsychModelDeviceMixin):
2626
_num_outputs = 1
2727
stimuli_per_trial = 2
2828
outcome_type = "binary"
@@ -63,7 +63,10 @@ def _get_index_of_equal_row(arr, x, dim=0):
6363
comparisons.append(comparison)
6464
else:
6565
comparisons.append(comparison[::-1])
66-
return unique_coords.T, torch.LongTensor(comparisons)
66+
67+
datapoints = unique_coords.T.to(self.device)
68+
comps = torch.LongTensor(comparisons).to(self.device)
69+
return datapoints, comps
6770

6871
def __init__(
6972
self,
@@ -165,18 +168,6 @@ def fit(
165168
fit_gpytorch_mll(mll, optimizer_kwargs=optimizer_kwargs, **kwargs)
166169
logger.info(f"Fit done, time={time.time() - starttime}")
167170

168-
def update(
169-
self, train_x: torch.Tensor, train_y: torch.Tensor, warmstart: bool = True
170-
) -> None:
171-
"""Perform a warm-start update of the model from previous fit.
172-
173-
Args:
174-
train_x (torch.Tensor): Train X.
175-
train_y (torch.Tensor): Train Y.
176-
warmstart (bool): If True, warm-start model fitting with current parameters. Defaults to True.
177-
"""
178-
self.fit(train_x, train_y)
179-
180171
def predict(
181172
self,
182173
x: torch.Tensor,

Diff for: tests_gpu/models/test_pairwise_probit.py

+99
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and its affiliates.
3+
# All rights reserved.
4+
5+
# This source code is licensed under the license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import unittest
9+
10+
import numpy as np
11+
import torch
12+
from aepsych.benchmark.test_functions import f_1d, f_pairwise
13+
from aepsych.generators import OptimizeAcqfGenerator, SobolGenerator
14+
from aepsych.models import PairwiseProbitModel
15+
from aepsych.strategy import SequentialStrategy, Strategy
16+
from aepsych.transforms import (
17+
ParameterTransformedGenerator,
18+
ParameterTransformedModel,
19+
ParameterTransforms,
20+
)
21+
from aepsych.transforms.ops import NormalizeScale
22+
from botorch.acquisition import qUpperConfidenceBound
23+
from scipy.stats import bernoulli
24+
25+
26+
class PairwiseProbitModelStrategyTest(unittest.TestCase):
27+
def test_1d_pairwise_probit(self):
28+
"""
29+
test our 1d gaussian bump example
30+
"""
31+
seed = 1
32+
torch.manual_seed(seed)
33+
np.random.seed(seed)
34+
n_init = 50
35+
n_opt = 1
36+
lb = torch.tensor([-4.0])
37+
ub = torch.tensor([4.0])
38+
extra_acqf_args = {"beta": 3.84}
39+
transforms = ParameterTransforms(
40+
normalize=NormalizeScale(d=1, bounds=torch.stack([lb, ub]))
41+
)
42+
sobol_gen = ParameterTransformedGenerator(
43+
generator=SobolGenerator,
44+
lb=lb,
45+
ub=ub,
46+
seed=seed,
47+
stimuli_per_trial=2,
48+
transforms=transforms,
49+
)
50+
acqf_gen = ParameterTransformedGenerator(
51+
generator=OptimizeAcqfGenerator,
52+
acqf=qUpperConfidenceBound,
53+
acqf_kwargs=extra_acqf_args,
54+
stimuli_per_trial=2,
55+
transforms=transforms,
56+
lb=lb,
57+
ub=ub,
58+
)
59+
probit_model = ParameterTransformedModel(
60+
model=PairwiseProbitModel, lb=lb, ub=ub, transforms=transforms
61+
).to("cuda")
62+
model_list = [
63+
Strategy(
64+
lb=lb,
65+
ub=ub,
66+
generator=sobol_gen,
67+
min_asks=n_init,
68+
stimuli_per_trial=2,
69+
outcome_types=["binary"],
70+
transforms=transforms,
71+
),
72+
Strategy(
73+
lb=lb,
74+
ub=ub,
75+
model=probit_model,
76+
generator=acqf_gen,
77+
min_asks=n_opt,
78+
stimuli_per_trial=2,
79+
outcome_types=["binary"],
80+
transforms=transforms,
81+
use_gpu_generating=True,
82+
use_gpu_modeling=True,
83+
),
84+
]
85+
86+
strat = SequentialStrategy(model_list)
87+
88+
for _i in range(n_init + n_opt):
89+
next_pair = strat.gen().cpu()
90+
strat.add_data(
91+
next_pair, [bernoulli.rvs(f_pairwise(f_1d, next_pair, noise_scale=0.1))]
92+
)
93+
94+
x = torch.linspace(-4, 4, 100)
95+
96+
zhat, _ = strat.predict(x)
97+
98+
self.assertTrue(np.abs(x[np.argmax(zhat.cpu().detach().numpy())]) < 0.5)
99+
self.assertTrue(strat.model.device.type == "cuda")

0 commit comments

Comments
 (0)