@@ -415,7 +415,126 @@ def forward(self, X: Tensor) -> Tensor:
         return _log_ei_helper(u) + sigma.log()
 
 
-class LogConstrainedExpectedImprovement(AnalyticAcquisitionFunction):
+class ConstrainedAnalyticAcquisitionFunctionMixin:
+    r"""Base class for constrained analytic acquisition functions."""
+
+    def __init__(
+        self,
+        constraints: dict[int, tuple[float | None, float | None]],
+        register_buffers: bool = False,
+    ) -> None:
+        r"""Initialize the constrained analytic acquisition function mixin.
+
+        Args:
+            constraints: A dictionary of the form `{i: [lower, upper]}`, where
+                `i` is the output index, and `lower` and `upper` are lower and upper
+                bounds on that output (resp. interpreted as -Inf / Inf if None).
+            register_buffers: If True, register the constraint bounds as PyTorch
+                buffers. Assumes that `self` is also a subclass of PyTorch's `Module`.
+        """
+        self.constraints = constraints
+        self._preprocess_constraint_bounds(constraints=constraints)
+        if register_buffers:
+            self._register_constraints_buffer()
+
+    def _preprocess_constraint_bounds(
+        self,
+        constraints: dict[int, tuple[float | None, float | None]],
+    ) -> None:
+        r"""Set up constraint bounds.
+
+        Args:
+            constraints: A dictionary of the form `{i: [lower, upper]}`, where
+                `i` is the output index, and `lower` and `upper` are lower and upper
+                bounds on that output (resp. interpreted as -Inf / Inf if None).
+        """
+        con_lower, con_lower_inds = [], []
+        con_upper, con_upper_inds = [], []
+        con_both, con_both_inds = [], []
+        con_indices = list(constraints.keys())
+        if len(con_indices) == 0:
+            raise ValueError("There must be at least one constraint.")
+        # CEI and LogCEI have an objective index, but LogPOF does not.
+        if hasattr(self, "objective_index") and self.objective_index in con_indices:
+            raise ValueError(
+                "Output corresponding to objective should not be a constraint."
+            )
+        for k in con_indices:
+            if constraints[k][0] is not None and constraints[k][1] is not None:
+                if constraints[k][1] <= constraints[k][0]:
+                    raise ValueError("Upper bound is less than the lower bound.")
+                con_both_inds.append(k)
+                con_both.append([constraints[k][0], constraints[k][1]])
+            elif constraints[k][0] is not None:
+                con_lower_inds.append(k)
+                con_lower.append(constraints[k][0])
+            elif constraints[k][1] is not None:
+                con_upper_inds.append(k)
+                con_upper.append(constraints[k][1])
+
+        for name, value in [
+            ("con_lower_inds", con_lower_inds),
+            ("con_upper_inds", con_upper_inds),
+            ("con_both_inds", con_both_inds),
+            ("con_both", con_both),
+            ("con_lower", con_lower),
+            ("con_upper", con_upper),
+        ]:
+            # tensor-based indexing is much faster than list-based advanced indexing
+            setattr(self, name, torch.as_tensor(value))
+
+    def _register_constraints_buffer(self) -> None:
+        """Convert the constraint fields to PyTorch buffers. Assumes that `self`
+        inherits from both this mixin and PyTorch's `Module`.
+        """
+        for name, value in [
+            ("con_lower_inds", self.con_lower_inds),
+            ("con_upper_inds", self.con_upper_inds),
+            ("con_both_inds", self.con_both_inds),
+            ("con_both", self.con_both),
+            ("con_lower", self.con_lower),
+            ("con_upper", self.con_upper),
+        ]:
+            delattr(self, name)
+            self.register_buffer(name, tensor=value)
+
+    def _compute_log_prob_feas(
+        self,
+        means: Tensor,
+        sigmas: Tensor,
+    ) -> Tensor:
+        r"""Compute the logarithm of the feasibility probability for each batch of X.
+
+        Args:
+            means: A `(b) x m`-dim Tensor of means.
+            sigmas: A `(b) x m`-dim Tensor of standard deviations.
+
+        Returns:
+            A `b`-dim tensor of log feasibility probabilities.
+
+        Note: This function does case-work for upper bound, lower bound, and both-sided
+            bounds. Another way to do it would be to use `inf` and `-inf` for the
+            one-sided bounds and use the logic for the both-sided case. But this
+            causes an issue with autograd since we get `0 * inf`.
+        """
+        return compute_log_prob_feas_from_bounds(
+            con_lower_inds=self.con_lower_inds,
+            con_upper_inds=self.con_upper_inds,
+            con_both_inds=self.con_both_inds,
+            con_lower=self.con_lower,
+            con_upper=self.con_upper,
+            con_both=self.con_both,
+            means=means,
+            sigmas=sigmas,
+        )
+
+
+class LogConstrainedExpectedImprovement(
+    AnalyticAcquisitionFunction, ConstrainedAnalyticAcquisitionFunctionMixin
+):
     r"""Log Constrained Expected Improvement (feasibility-weighted).
 
     Computes the logarithm of the analytic expected improvement for a Normal posterior
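Note for readers: a minimal sketch of how the new mixin's `_preprocess_constraint_bounds` partitions a constraints dict into index and bound tensors. The toy subclass, the constraint values, and the import path are illustrative assumptions (the mixin is only importable once this diff lands); `register_buffers=True` would additionally require `self` to be a `torch.nn.Module`.

```python
from botorch.acquisition.analytic import ConstrainedAnalyticAcquisitionFunctionMixin

class _ToyConstrained(ConstrainedAnalyticAcquisitionFunctionMixin):
    """Hypothetical subclass; not a torch Module, so the bounds stay plain attributes."""

acqf = _ToyConstrained(
    constraints={0: (0.0, None), 1: (None, 1.0), 2: (-1.0, 1.0)},
    register_buffers=False,
)
print(acqf.con_lower_inds)  # tensor([0]) -- output 0 has only a lower bound
print(acqf.con_upper_inds)  # tensor([1]) -- output 1 has only an upper bound
print(acqf.con_both_inds)   # tensor([2]) -- output 2 is bounded on both sides
print(acqf.con_both)        # tensor([[-1., 1.]])
```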
@@ -464,13 +583,14 @@ def __init__(
             maximize: If True, consider the problem a maximization problem.
         """
         # Use AcquisitionFunction constructor to avoid check for posterior transform.
-        super(AnalyticAcquisitionFunction, self).__init__(model=model)
+        AcquisitionFunction.__init__(self, model=model)
         self.posterior_transform = None
         self.maximize = maximize
         self.objective_index = objective_index
-        self.constraints = constraints
         self.register_buffer("best_f", torch.as_tensor(best_f))
-        _preprocess_constraint_bounds(self, constraints=constraints)
+        ConstrainedAnalyticAcquisitionFunctionMixin.__init__(
+            self, constraints, register_buffers=True
+        )
         self.register_forward_pre_hook(convert_to_target_pre_hook)
 
     @t_batch_mode_transform(expected_q=1)
@@ -490,11 +610,80 @@ def forward(self, X: Tensor) -> Tensor:
         mean_obj, sigma_obj = means[..., ind], sigmas[..., ind]
         u = _scaled_improvement(mean_obj, sigma_obj, self.best_f, self.maximize)
         log_ei = _log_ei_helper(u) + sigma_obj.log()
-        log_prob_feas = _compute_log_prob_feas(self, means=means, sigmas=sigmas)
+        log_prob_feas = self._compute_log_prob_feas(means=means, sigmas=sigmas)
         return log_ei + log_prob_feas
 
 
-class ConstrainedExpectedImprovement(AnalyticAcquisitionFunction):
+class LogProbabilityOfFeasibility(
+    AnalyticAcquisitionFunction, ConstrainedAnalyticAcquisitionFunctionMixin
+):
+    r"""Log Probability of Feasibility.
+
+    Computes the logarithm of the analytic probability of feasibility for a Normal
+    posterior distribution. The constraints are assumed to be independent and have
+    Gaussian posterior distributions. Only supports non-batch mode (i.e. `q=1`). The
+    model should be multi-outcome, with the constraint indices passed to the
+    constructor.
+
+    See [Ament2023logei]_ for details. Formally,
+
+    `LogPOF(x) = Sum_i log(P(y_i \in [lower_i, upper_i]))`,
+
+    where `y_i ~ constraint_i(x)` and `lower_i`, `upper_i` are the lower and
+    upper bounds for the i-th constraint, respectively.
+
+    Example:
+        # example where the 0th output has a non-negativity constraint
+        >>> model = SingleTaskGP(train_X, train_Y)
+        >>> constraints = {0: (0.0, None)}
+        >>> LogPOF = LogProbabilityOfFeasibility(model, constraints)
+        >>> log_pof = LogPOF(test_X)
+    """
+
+    _log: bool = True
+
+    def __init__(
+        self,
+        model: Model,
+        constraints: dict[int, tuple[float | None, float | None]],
+    ) -> None:
+        r"""Analytic Log Probability of Feasibility.
+
+        Args:
+            model: A fitted multi-output model.
+            constraints: A dictionary of the form `{i: [lower, upper]}`, where
+                `i` is the output index, and `lower` and `upper` are lower and upper
+                bounds on that output (resp. interpreted as -Inf / Inf if None).
+        """
+        # Use AcquisitionFunction constructor to avoid check for posterior transform.
+        AcquisitionFunction.__init__(self, model=model)
+        self.posterior_transform = None
+        ConstrainedAnalyticAcquisitionFunctionMixin.__init__(
+            self, constraints, register_buffers=True
+        )
+        self.register_forward_pre_hook(convert_to_target_pre_hook)
+
+    @t_batch_mode_transform(expected_q=1)
+    def forward(self, X: Tensor) -> Tensor:
+        r"""Evaluate the Log Probability of Feasibility on the candidate set X.
+
+        Args:
+            X: A `(b) x 1 x d`-dim Tensor of `(b)` t-batches of `d`-dim design
+                points each.
+
+        Returns:
+            A `(b)`-dim Tensor of Log Probability of Feasibility values at the given
+            design points `X`.
+        """
+        means, sigmas = self._mean_and_sigma(X)  # (b) x 1 + (m = num constraints)
+        return self._compute_log_prob_feas(means=means, sigmas=sigmas)
+
+
+class ConstrainedExpectedImprovement(
+    AnalyticAcquisitionFunction, ConstrainedAnalyticAcquisitionFunctionMixin
+):
     r"""Constrained Expected Improvement (feasibility-weighted).
 
     Computes the analytic expected improvement for a Normal posterior
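For reviewers who want to exercise the new `LogProbabilityOfFeasibility` class end to end, here is a hedged sketch; the training data, shapes, and the `<= 0` constraint on output 1 are illustrative, and the import path assumes this diff has landed.

```python
import torch
from botorch.models import SingleTaskGP
from botorch.acquisition.analytic import LogProbabilityOfFeasibility

train_X = torch.rand(10, 2, dtype=torch.float64)
train_Y = torch.randn(10, 2, dtype=torch.float64)  # output 1 is the constrained outcome
model = SingleTaskGP(train_X, train_Y)

# feasible iff output 1 is at most 0.0 (upper bound only)
log_pof = LogProbabilityOfFeasibility(model=model, constraints={1: (None, 0.0)})

test_X = torch.rand(5, 1, 2, dtype=torch.float64)  # 5 t-batches of q=1 points
print(log_pof(test_X))  # a 5-dim tensor of log-probabilities (all <= 0)
```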
@@ -543,13 +732,14 @@ def __init__(
         """
         legacy_ei_numerics_warning(legacy_name=type(self).__name__)
         # Use AcquisitionFunction constructor to avoid check for posterior transform.
-        super(AnalyticAcquisitionFunction, self).__init__(model=model)
+        AcquisitionFunction.__init__(self, model=model)
         self.posterior_transform = None
         self.maximize = maximize
         self.objective_index = objective_index
-        self.constraints = constraints
         self.register_buffer("best_f", torch.as_tensor(best_f))
-        _preprocess_constraint_bounds(self, constraints=constraints)
+        ConstrainedAnalyticAcquisitionFunctionMixin.__init__(
+            self, constraints, register_buffers=True
+        )
         self.register_forward_pre_hook(convert_to_target_pre_hook)
 
     @t_batch_mode_transform(expected_q=1)
@@ -569,7 +759,7 @@ def forward(self, X: Tensor) -> Tensor:
         mean_obj, sigma_obj = means[..., ind], sigmas[..., ind]
         u = _scaled_improvement(mean_obj, sigma_obj, self.best_f, self.maximize)
         ei = sigma_obj * _ei_helper(u)
-        log_prob_feas = _compute_log_prob_feas(self, means=means, sigmas=sigmas)
+        log_prob_feas = self._compute_log_prob_feas(means=means, sigmas=sigmas)
         return ei.mul(log_prob_feas.exp())
 
 
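Since `LogConstrainedExpectedImprovement.forward` returns `log_ei + log_prob_feas` while `ConstrainedExpectedImprovement.forward` returns `ei * prob_feas`, exponentiating the former should recover the latter up to floating-point error. A hedged sanity-check sketch (data and tolerance are illustrative; constructing the legacy class emits a numerics warning):

```python
import torch
from botorch.models import SingleTaskGP
from botorch.acquisition.analytic import (
    ConstrainedExpectedImprovement,
    LogConstrainedExpectedImprovement,
)

train_X = torch.rand(10, 2, dtype=torch.float64)
train_Y = torch.randn(10, 3, dtype=torch.float64)  # output 0: objective, 1 and 2: constraints
model = SingleTaskGP(train_X, train_Y)
kwargs = dict(
    model=model,
    best_f=train_Y[:, 0].max(),
    objective_index=0,
    constraints={1: (None, 0.0), 2: (0.0, None)},
)
cei = ConstrainedExpectedImprovement(**kwargs)
log_cei = LogConstrainedExpectedImprovement(**kwargs)

test_X = torch.rand(5, 1, 2, dtype=torch.float64)
print(torch.allclose(log_cei(test_X).exp(), cei(test_X), atol=1e-8))  # expected: True
```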
@@ -1131,82 +1321,3 @@ def _get_noiseless_fantasy_model(
     fantasy_model.likelihood.noise_covar.noise = Yvar
 
     return fantasy_model
-
-
-def _preprocess_constraint_bounds(
-    acqf: LogConstrainedExpectedImprovement | ConstrainedExpectedImprovement,
-    constraints: dict[int, tuple[float | None, float | None]],
-) -> None:
-    r"""Set up constraint bounds.
-
-    Args:
-        constraints: A dictionary of the form `{i: [lower, upper]}`, where
-            `i` is the output index, and `lower` and `upper` are lower and upper
-            bounds on that output (resp. interpreted as -Inf / Inf if None)
-    """
-    con_lower, con_lower_inds = [], []
-    con_upper, con_upper_inds = [], []
-    con_both, con_both_inds = [], []
-    con_indices = list(constraints.keys())
-    if len(con_indices) == 0:
-        raise ValueError("There must be at least one constraint.")
-    if acqf.objective_index in con_indices:
-        raise ValueError(
-            "Output corresponding to objective should not be a constraint."
-        )
-    for k in con_indices:
-        if constraints[k][0] is not None and constraints[k][1] is not None:
-            if constraints[k][1] <= constraints[k][0]:
-                raise ValueError("Upper bound is less than the lower bound.")
-            con_both_inds.append(k)
-            con_both.append([constraints[k][0], constraints[k][1]])
-        elif constraints[k][0] is not None:
-            con_lower_inds.append(k)
-            con_lower.append(constraints[k][0])
-        elif constraints[k][1] is not None:
-            con_upper_inds.append(k)
-            con_upper.append(constraints[k][1])
-    # tensor-based indexing is much faster than list-based advanced indexing
-    for name, indices in [
-        ("con_lower_inds", con_lower_inds),
-        ("con_upper_inds", con_upper_inds),
-        ("con_both_inds", con_both_inds),
-        ("con_both", con_both),
-        ("con_lower", con_lower),
-        ("con_upper", con_upper),
-    ]:
-        acqf.register_buffer(name, tensor=torch.as_tensor(indices))
-
-
-def _compute_log_prob_feas(
-    acqf: LogConstrainedExpectedImprovement | ConstrainedExpectedImprovement,
-    means: Tensor,
-    sigmas: Tensor,
-) -> Tensor:
-    r"""Compute logarithm of the feasibility probability for each batch of X.
-
-    Args:
-        X: A `(b) x 1 x d`-dim Tensor of `(b)` t-batches of `d`-dim design
-            points each.
-        means: A `(b) x m`-dim Tensor of means.
-        sigmas: A `(b) x m`-dim Tensor of standard deviations.
-    Returns:
-        A `b`-dim tensor of log feasibility probabilities
-
-    Note: This function does case-work for upper bound, lower bound, and both-sided
-        bounds. Another way to do it would be to use 'inf' and -'inf' for the
-        one-sided bounds and use the logic for the both-sided case. But this
-        causes an issue with autograd since we get 0 * inf.
-    TODO: Investigate further.
-    """
-    acqf.to(device=means.device)
-    return compute_log_prob_feas_from_bounds(
-        acqf.con_lower_inds,
-        acqf.con_upper_inds,
-        acqf.con_both_inds,
-        acqf.con_lower,
-        acqf.con_upper,
-        acqf.con_both,
-        means,
-        sigmas,
-    )
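For context on what `compute_log_prob_feas_from_bounds` computes per constrained output: under a Gaussian posterior with mean `mu` and standard deviation `sigma`, the feasibility probability for a two-sided bound is a difference of Normal CDFs, and the one-sided cases drop one term. A minimal numerical sketch, not the library implementation (which stays in log space and does the case-work noted in the docstring above):

```python
import torch
from torch.distributions import Normal

def log_prob_feas_two_sided(mu, sigma, lower, upper):
    # P(lower <= y <= upper) = Phi((upper - mu) / sigma) - Phi((lower - mu) / sigma)
    std_normal = Normal(torch.zeros_like(mu), torch.ones_like(mu))
    prob = std_normal.cdf((upper - mu) / sigma) - std_normal.cdf((lower - mu) / sigma)
    return prob.clamp_min(1e-300).log()  # naive log; log-CDF based math is more stable

mu = torch.tensor([0.5], dtype=torch.float64)
sigma = torch.tensor([1.0], dtype=torch.float64)
print(log_prob_feas_two_sided(mu, sigma, lower=-1.0, upper=1.0))  # log P(-1 <= y <= 1)
```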