20
20
from .. import settings as s
21
21
from .. import utilities as u
22
22
from .. import interface as intf
23
+ from ..expressions .constants import Constant
23
24
from ..expressions .variables import Variable
24
25
from ..expressions .expression import Expression
25
- from ..expressions .affine import AffExpression
26
- from ..constraints .affine import AffEqConstraint , AffLeqConstraint
27
- from constant_atom import ConstantAtom
28
26
import abc
29
27
30
28
class Atom (Expression ):
@@ -40,24 +38,22 @@ def __init__(self, *args):
40
38
# Convert raw values to Constants.
41
39
self .args = map (Expression .cast_to_const , args )
42
40
self .subexpressions = self .args
43
- # Initialize context.
44
- self .set_context ()
45
41
super (Atom , self ).__init__ ()
46
42
47
43
# Returns the string representation of the function call.
48
44
def name (self ):
49
45
return "%s(%s)" % (self .__class__ .__name__ ,
50
46
", " .join ([arg .name () for arg in self .args ]))
51
47
52
- # Sets signed curvature based on the arguments' signed curvatures .
53
- def set_context (self ):
48
+ # Determines the curvature, sign, and shape from the arguments .
49
+ def _dcp_attr (self ):
54
50
# Initialize _shape. Raises an error for invalid argument sizes.
55
- self .set_shape ()
51
+ shape = self .shape_from_args ()
56
52
sign = self .sign_from_args ()
57
53
curvature = Atom .dcp_curvature (self .base_curvature (),
58
54
self .args ,
59
55
self .monotonicity ())
60
- self ._context = u .Context (sign , curvature , self ._shape )
56
+ self ._context = u .DCPAttr (sign , curvature , self ._shape )
61
57
62
58
# Returns argument curvatures as a list.
63
59
def argument_curvatures (self ):
@@ -91,30 +87,24 @@ def dcp_curvature(curvature, args, monotonicities):
91
87
def canonicalize (self ):
92
88
# Constant atoms are treated as a leaf.
93
89
if self .curvature .is_constant ():
94
- obj = AffExpression ({s .CONSTANT : self }, self .shape )
95
- return (obj , [])
96
- # Non-constant atoms are expanded into an affine objective and constraints.
90
+ return Constant (self .value ).canonicalize ()
97
91
else :
98
- var_args = []
99
- final_constraints = []
92
+ arg_objs = []
93
+ constraints = []
100
94
for arg in self .args :
101
- # canonicalize arguments.
102
- obj ,constraints = arg .canonical_form ()
103
- var_args .append (obj )
104
- final_constraints += constraints
105
- graph_var ,graph_constr = self .graph_implementation (var_args , self .size )
106
- obj = u .Affine .cast_as_affine (graph_var )
107
- return (obj ,final_constraints + graph_constr )
95
+ obj ,constr = arg .canonicalize ()
96
+ arg_objs .append (obj )
97
+ constraints += constr
98
+ graph_obj ,graph_constr = self .graph_implementation (arg_objs )
99
+ return (graph_obj , constraints + graph_constr )
108
100
109
- # Returns a variable and set of affine/SOC
101
+ # Returns an affine expression and list of
110
102
# constraints equivalent to the atom.
111
- # var_args - a list of single variable arguments.
112
- # size - the dimensions of the variable to return.
103
+ # arg_objs - the canonical objectives of the arguments.
113
104
@abc .abstractmethod
114
- def graph_implementation (var_args , size ):
105
+ def graph_implementation (self , arg_objs ):
115
106
return NotImplemented
116
107
117
-
118
108
# Wraps an atom's numeric function that requires numpy ndarrays as input.
119
109
# Ensures both inputs and outputs are the correct matrix types.
120
110
@staticmethod
0 commit comments