Skip to content

Commit

Permalink
Reduce greatly the peppering of object.Value() and return the actual …
Browse files Browse the repository at this point in the history
…type for REFERENCE instead of referenced object's type
  • Loading branch information
ldemailly committed Sep 3, 2024
1 parent ab6e16f commit b00cb0e
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 52 deletions.
89 changes: 45 additions & 44 deletions eval/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func (s *State) evalIndexAssigment(which ast.Node, index, value object.Object) o
if index.Type() != object.INTEGER {
return s.NewError("index assignment to array with non integer index: " + index.Inspect())
}
idx := object.Value(index).(object.Integer).Value
idx := index.(object.Integer).Value
if idx < 0 {
idx = int64(object.Len(val)) + idx
}
Expand All @@ -72,7 +72,7 @@ func (s *State) evalIndexAssigment(which ast.Node, index, value object.Object) o
}
return value
case object.MAP:
m := object.Value(val).(object.Map)
m := val.(object.Map)
m = m.Set(index, value)
oerr := s.env.Set(id.Literal(), m)
if oerr.Type() == object.ERROR {
Expand Down Expand Up @@ -104,7 +104,7 @@ func (s *State) evalPrefixIncrDecr(operator token.Type, node ast.Node) object.Ob
log.LogVf("eval prefix %s", ast.DebugString(node))
nv := node.Value()
if nv.Type() != token.IDENT {
return s.NewError("can't increment/decrement " + nv.DebugString())
return s.NewError("can't prefix increment/decrement " + nv.DebugString())
}
id := nv.Literal()
val, ok := s.env.Get(id)
Expand All @@ -122,7 +122,7 @@ func (s *State) evalPrefixIncrDecr(operator token.Type, node ast.Node) object.Ob
case object.Float:
return s.env.Set(id, object.Float{Value: val.Value + float64(toAdd)}) // So PI++ fails not silently.
default:
return s.NewError("can't increment/decrement " + val.Type().String())
return s.NewError("can't prefix increment/decrement " + val.Type().String())
}
}

Expand All @@ -142,15 +142,15 @@ func (s *State) evalPostfixExpression(node *ast.PostfixExpression) object.Object
default:
return s.NewError("unknown postfix operator: " + node.Type().String())
}
var oerr object.Object
val = object.Value(val)
var oerr object.Object
switch val := val.(type) {
case object.Integer:
oerr = s.env.Set(id, object.Integer{Value: val.Value + toAdd})
case object.Float:
oerr = s.env.Set(id, object.Float{Value: val.Value + float64(toAdd)}) // So PI++ fails not silently.
default:
return s.NewError("can't increment/decrement " + val.Type().String())
return s.NewError("can't postfix increment/decrement " + val.Type().String())
}
if oerr.Type() == object.ERROR {
return oerr
Expand Down Expand Up @@ -196,7 +196,7 @@ func (s *State) evalInternal(node any) object.Object { //nolint:funlen,gocyclo /
log.LogVf("eval infix %s", node.DebugString())
// Eval and not evalInternal because we need to unwrap "return".
if node.Token.Type() == token.ASSIGN || node.Token.Type() == token.DEFINE {
return s.evalAssignment(s.Eval(node.Right), node)
return s.evalAssignment(s.evalInternal(node.Right), node)
}
// Humans expect left to right evaluations.
left := s.Eval(node.Left)
Expand Down Expand Up @@ -260,7 +260,7 @@ func (s *State) evalInternal(node any) object.Object { //nolint:funlen,gocyclo /
}
return fn
case *ast.CallExpression:
f := s.evalInternal(node.Function)
f := s.Eval(node.Function)
if f.Type() == object.ERROR {
return f
}
Expand All @@ -269,7 +269,7 @@ func (s *State) evalInternal(node any) object.Object { //nolint:funlen,gocyclo /
return *oerr
}
if f.Type() == object.EXTENSION {
return s.applyExtension(object.Value(f).(object.Extension), args)
return s.applyExtension(f.(object.Extension), args)
}
name := node.Function.Value().Literal()
return s.applyFunction(name, f, args)
Expand All @@ -282,7 +282,7 @@ func (s *State) evalInternal(node any) object.Object { //nolint:funlen,gocyclo /
case *ast.MapLiteral:
return s.evalMapLiteral(node)
case *ast.IndexExpression:
return s.evalIndexExpression(s.evalInternal(node.Left), node)
return s.evalIndexExpression(s.Eval(node.Left), node)
case *ast.Comment:
return object.NULL
}
Expand All @@ -306,7 +306,7 @@ func (s *State) evalIndexExpression(left object.Object, node *ast.IndexExpressio
rangeExp := node.Index.(*ast.InfixExpression)
return s.evalIndexRangeExpression(left, rangeExp.Left, rangeExp.Right)
}
index = s.evalInternal(node.Index)
index = s.Eval(node.Index)
if index.Type() == object.ERROR {
return index
}
Expand Down Expand Up @@ -345,7 +345,7 @@ func (s *State) evalPrintLogError(node *ast.Builtin) object.Object {
return r
}
if isString := r.Type() == object.STRING; isString {
buf.WriteString(object.Value(r).(object.String).Value)
buf.WriteString(r.(object.String).Value)
} else {
buf.WriteString(r.Inspect())
}
Expand Down Expand Up @@ -415,28 +415,28 @@ func (s *State) evalBuiltin(node *ast.Builtin) object.Object {
}

func (s *State) evalIndexRangeExpression(left object.Object, leftIdx, rightIdx ast.Node) object.Object {
leftIndex := s.evalInternal(leftIdx)
leftIndex := s.Eval(leftIdx)
nilRight := (rightIdx == nil)
var rightIndex object.Object
if nilRight {
log.Debugf("eval index %s[%s:]", left.Inspect(), leftIndex.Inspect())
} else {
rightIndex = s.evalInternal(rightIdx)
rightIndex = s.Eval(rightIdx)
log.Debugf("eval index %s[%s:%s]", left.Inspect(), leftIndex.Inspect(), rightIndex.Inspect())
}
if leftIndex.Type() != object.INTEGER || (!nilRight && rightIndex.Type() != object.INTEGER) {
return s.NewError("range index not integer")
}
num := object.Len(left)
l := object.Value(leftIndex).(object.Integer).Value
l := leftIndex.(object.Integer).Value
if l < 0 { // negative is relative to the end.
l = int64(num) + l
}
var r int64
if nilRight {
r = int64(num)
} else {
r = object.Value(rightIndex).(object.Integer).Value
r = rightIndex.(object.Integer).Value
if r < 0 {
r = int64(num) + r
}
Expand All @@ -448,7 +448,7 @@ func (s *State) evalIndexRangeExpression(left object.Object, leftIdx, rightIdx a
r = min(r, int64(num))
switch {
case left.Type() == object.STRING:
str := object.Value(left).(object.String).Value
str := left.(object.String).Value
return object.String{Value: str[l:r]}
case left.Type() == object.ARRAY:
return object.NewArray(object.Elements(left)[l:r])
Expand All @@ -468,8 +468,8 @@ func (s *State) evalIndexExpressionIdx(left, index object.Object) object.Object
}
switch {
case left.Type() == object.STRING && idxOrZero.Type() == object.INTEGER:
idx := object.Value(idxOrZero).(object.Integer).Value
str := object.Value(left).(object.String).Value
idx := idxOrZero.(object.Integer).Value
str := left.(object.String).Value
num := len(str)
if idx < 0 { // negative is relative to the end.
idx = int64(num) + idx
Expand All @@ -490,7 +490,7 @@ func (s *State) evalIndexExpressionIdx(left, index object.Object) object.Object
}

func evalMapIndexExpression(assoc, key object.Object) object.Object {
m := object.Value(assoc).(object.Map)
m := assoc.(object.Map)
v, ok := m.Get(key)
if !ok {
return object.NULL
Expand All @@ -499,7 +499,7 @@ func evalMapIndexExpression(assoc, key object.Object) object.Object {
}

func evalArrayIndexExpression(array, index object.Object) object.Object {
idx := object.Value(index).(object.Integer).Value
idx := index.(object.Integer).Value
maxV := int64(object.Len(array) - 1)
if idx < 0 { // negative is relative to the end.
idx = maxV + 1 + idx // elsewhere we use len() but here maxV is len-1
Expand Down Expand Up @@ -538,7 +538,7 @@ func (s *State) applyExtension(fn object.Extension, args []object.Object) object
}
// Auto promote integer to float if needed.
if fn.ArgTypes[i] == object.FLOAT && arg.Type() == object.INTEGER {
args[i] = object.Float{Value: float64(object.Value(arg).(object.Integer).Value)}
args[i] = object.Float{Value: float64(arg.(object.Integer).Value)}
continue
}
if fn.ArgTypes[i] != arg.Type() {
Expand All @@ -553,7 +553,7 @@ func (s *State) applyExtension(fn object.Extension, args []object.Object) object
}

func (s *State) applyFunction(name string, fn object.Object, args []object.Object) object.Object {
function, ok := object.Value(fn).(object.Function)
function, ok := fn.(object.Function)
if !ok {
return s.NewError("not a function: " + fn.Type().String() + ":" + fn.Inspect())
}
Expand Down Expand Up @@ -637,10 +637,11 @@ func extendFunctionEnv(
name, len(args), atLeast, n)}
}
for paramIdx, param := range params {
oerr := env.CreateOrSet(param.Value().Literal(), args[paramIdx], true)
// By definition function parameters are local copies, deref argument values:
oerr := env.CreateOrSet(param.Value().Literal(), object.Value(args[paramIdx]), true)
log.LogVf("set %s to %s - %s", param.Value().Literal(), args[paramIdx].Inspect(), oerr.Inspect())
if oerr.Type() == object.ERROR {
oe, _ := object.Value(oerr).(object.Error)
oe, _ := oerr.(object.Error)
return nil, &oe
}
}
Expand All @@ -662,7 +663,7 @@ func (s *State) evalExpressions(exps []ast.Node) ([]object.Object, *object.Error
for _, e := range exps {
evaluated := s.evalInternal(e)
if rt := evaluated.Type(); rt == object.ERROR {
oerr := object.Value(evaluated).(object.Error)
oerr := evaluated.(object.Error)
return nil, &oerr
}
result = append(result, evaluated)
Expand Down Expand Up @@ -761,14 +762,14 @@ func (s *State) evalForSpecialForms(fe *ast.ForExpression) (object.Object, bool)
if end.Type() != object.INTEGER {
return s.NewError("for var = n:m m not an integer: " + end.Inspect()), true
}
startInt := object.Value(start).(object.Integer)
return s.evalForInteger(fe, &startInt, object.Value(end).(object.Integer), name), true
startInt := start.(object.Integer)
return s.evalForInteger(fe, &startInt, end.(object.Integer), name), true
}
// Evaluate:
v := s.Eval(ie.Right)
v := s.evalInternal(ie.Right)
switch v.Type() {
case object.INTEGER:
return s.evalForInteger(fe, nil, object.Value(v).(object.Integer), name), true
return s.evalForInteger(fe, nil, v.(object.Integer), name), true
case object.ERROR:
return v, true
case object.ARRAY, object.MAP, object.STRING:
Expand Down Expand Up @@ -837,7 +838,7 @@ func (s *State) evalForExpression(fe *ast.ForExpression) object.Object {
case object.ERROR:
return condition
case object.INTEGER:
return s.evalForInteger(fe, nil, object.Value(condition).(object.Integer), "")
return s.evalForInteger(fe, nil, condition.(object.Integer), "")
default:
return s.NewError("for condition is not a boolean nor integer nor assignment: " + condition.Inspect())
}
Expand Down Expand Up @@ -878,7 +879,7 @@ func (s *State) evalPrefixExpression(operator token.Type, right object.Object) o
return s.evalMinusPrefixOperatorExpression(right)
case token.BITNOT, token.BITXOR:
if right.Type() == object.INTEGER {
return object.Integer{Value: ^object.Value(right).(object.Integer).Value}
return object.Integer{Value: ^right.(object.Integer).Value}
}
return s.NewError("bitwise not of " + right.Inspect())
case token.PLUS:
Expand All @@ -905,10 +906,10 @@ func (s *State) evalBangOperatorExpression(right object.Object) object.Object {
func (s *State) evalMinusPrefixOperatorExpression(right object.Object) object.Object {
switch right.Type() {
case object.INTEGER:
value := object.Value(right).(object.Integer).Value
value := right.(object.Integer).Value
return object.Integer{Value: -value}
case object.FLOAT:
value := object.Value(right).(object.Float).Value
value := right.(object.Float).Value
return object.Float{Value: -value}
default:
return s.NewError("minus of " + right.Inspect())
Expand Down Expand Up @@ -950,13 +951,13 @@ func (s *State) evalInfixExpression(operator token.Type, left, right object.Obje
}

func (s *State) evalStringInfixExpression(operator token.Type, left, right object.Object) object.Object {
leftVal := object.Value(left).(object.String).Value
leftVal := left.(object.String).Value
switch {
case operator == token.PLUS && right.Type() == object.STRING:
rightVal := object.Value(right).(object.String).Value
rightVal := right.(object.String).Value
return object.String{Value: leftVal + rightVal}
case operator == token.ASTERISK && right.Type() == object.INTEGER:
rightVal := object.Value(right).(object.Integer).Value
rightVal := right.(object.Integer).Value
n := len(leftVal) * int(rightVal)
if rightVal < 0 {
return s.NewError("right operand of * on strings must be a positive integer")
Expand All @@ -977,7 +978,7 @@ func (s *State) evalArrayInfixExpression(operator token.Type, left, right object
return s.NewError("right operand of * on arrays must be an integer")
}
// TODO: go1.23 use slices.Repeat
rightVal := object.Value(right).(object.Integer).Value
rightVal := right.(object.Integer).Value
if rightVal < 0 {
return s.NewError("right operand of * on arrays must be a positive integer")
}
Expand All @@ -1000,8 +1001,8 @@ func (s *State) evalArrayInfixExpression(operator token.Type, left, right object
}

func evalMapInfixExpression(operator token.Type, left, right object.Object) object.Object {
leftMap := object.Value(left).(object.Map)
rightMap := object.Value(right).(object.Map)
leftMap := left.(object.Map)
rightMap := right.(object.Map)
switch operator {
case token.PLUS: // concat / append
return leftMap.Append(rightMap)
Expand All @@ -1016,8 +1017,8 @@ func evalMapInfixExpression(operator token.Type, left, right object.Object) obje
// https://github.com/golang/go/issues/48522
// would need getters/setters which is not very go idiomatic.
func (s *State) evalIntegerInfixExpression(operator token.Type, left, right object.Object) object.Object {
leftVal := object.Value(left).(object.Integer).Value
rightVal := object.Value(right).(object.Integer).Value
leftVal := left.(object.Integer).Value
rightVal := right.(object.Integer).Value

switch operator {
case token.PLUS:
Expand Down Expand Up @@ -1058,9 +1059,9 @@ func (s *State) evalIntegerInfixExpression(operator token.Type, left, right obje
func (s *State) getFloatValue(o object.Object) (float64, *object.Error) {
switch o.Type() {
case object.INTEGER:
return float64(object.Value(o).(object.Integer).Value), nil
return float64(o.(object.Integer).Value), nil
case object.FLOAT:
return object.Value(o).(object.Float).Value, nil
return o.(object.Float).Value, nil
default:
e := s.NewError("not converting to float: " + o.Type().String())
return math.NaN(), &e
Expand Down
5 changes: 4 additions & 1 deletion eval/eval_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,10 @@ func (s *State) Eval(node any) object.Object {
if returnValue.ControlType != token.RETURN {
return s.Errorf("unexpected control type %v outside of for loops", returnValue.ControlType)
}
return returnValue.Value
result = returnValue.Value
}
if refValue, ok := result.(object.Reference); ok {
return object.Value(refValue)
}
return result
}
Expand Down
4 changes: 2 additions & 2 deletions eval/eval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func testEval(t *testing.T, input string) object.Object {
}

func testIntegerObject(t *testing.T, obj object.Object, expected int64) bool {
result, ok := object.Value(obj).(object.Integer)
result, ok := obj.(object.Integer)
if !ok {
t.Errorf("object is not Integer. got=%T (%+v)", obj, obj)
return false
Expand Down Expand Up @@ -175,7 +175,7 @@ func TestBangOperator(t *testing.T) {
}

func testBooleanObject(t *testing.T, obj object.Object, expected bool) {
result, ok := object.Value(obj).(object.Boolean)
result, ok := obj.(object.Boolean)
if !ok {
t.Errorf("object is not Boolean. got=%T (%+v)", obj, obj)
return
Expand Down
8 changes: 4 additions & 4 deletions extensions/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ func createStrFunctions() {
if a.Type() != object.STRING {
strs[i] = a.Inspect()
} else {
strs[i] = object.Value(a).(object.String).Value
strs[i] = a.(object.String).Value
}
totalLen += len(strs[i]) + sepLen
}
Expand Down Expand Up @@ -353,14 +353,14 @@ func createMisc() {
case object.NIL:
return object.Integer{Value: 0}
case object.BOOLEAN:
if object.Value(o).(object.Boolean).Value {
if o.(object.Boolean).Value {
return object.Integer{Value: 1}
}
return object.Integer{Value: 0}
case object.FLOAT:
return object.Integer{Value: int64(object.Value(o).(object.Float).Value)}
return object.Integer{Value: int64(o.(object.Float).Value)}
case object.STRING:
i, serr := strconv.ParseInt(object.Value(o).(object.String).Value, 0, 64)
i, serr := strconv.ParseInt(o.(object.String).Value, 0, 64)
if serr != nil {
return s.Error(serr)
}
Expand Down
2 changes: 1 addition & 1 deletion object/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -1055,7 +1055,7 @@ func (r Reference) Value() Object {
}

func (r Reference) Unwrap(str bool) any { return r.Value().Unwrap(str) }
func (r Reference) Type() Type { return r.Value().Type() }
func (r Reference) Type() Type { return REFERENCE }
func (r Reference) Inspect() string { return r.Value().Inspect() }
func (r Reference) JSON(w io.Writer) error { return r.Value().JSON(w) }

Expand Down

0 comments on commit b00cb0e

Please sign in to comment.