RASP validator fails for some programs #11

Open
langosco opened this issue Nov 20, 2023 · 5 comments

langosco commented Nov 20, 2023

Issue #9 introduces a validator that checks for RASP programs that compile incorrectly.
Here's one case (a RASP program that computes the sum of all inputs up to the current index) in which I think the validator fails, unless I've misunderstood how it works:

from tracr.rasp import rasp
from tracr.compiler import validating, compiling


def sum_of_inputs() -> rasp.SOp:
    before = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LEQ)
    means = rasp.Aggregate(before, rasp.tokens)  # returns sequence s_i = mean_{j<=i} input_j
    sums = rasp.SequenceMap(lambda x, y: x*y, means, rasp.indices+1)
    return sums


sums = sum_of_inputs()

# The output of the RASP program sums is different from the output of the compiled model:
rasp_output = sums([3, 2, 1, 1])
compiled_model = compiling.compile_rasp_to_model(sums, vocab={1,2,3}, max_seq_len=5, compiler_bos="BOS")
compiled_output = compiled_model.apply(["BOS", 3, 2, 1, 1]).decoded

print(rasp_output)  # output: [3.0, 5.0, 6.0, 7.0]
print(compiled_output)  # output: ['BOS', 3, 4, 3, 4]

# However, it looks like the validator doesn't catch the error:
print(validating.validate(sums, [1, 2, 3]))  # returns an empty list
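
For reference, the arithmetic this program is meant to perform can be reproduced in plain Python without tracr. This is only a sketch of the intended semantics (causal mean, then rescaling by index + 1); `prefix_sums_via_means` is a hypothetical helper name and not part of tracr.

```python
def prefix_sums_via_means(tokens):
    """Mimic sum_of_inputs: causal mean at each position, times (index + 1)."""
    out = []
    for i in range(len(tokens)):
        prefix = tokens[: i + 1]          # positions j <= i
        mean = sum(prefix) / len(prefix)  # what the Aggregate step should yield
        out.append(mean * (i + 1))        # the SequenceMap step rescales the mean
    return out


result = prefix_sums_via_means([3, 2, 1, 1])
print(result)  # [3.0, 5.0, 6.0, 7.0], matching rasp_output above
```

The compiled model's `[3, 4, 3, 4]` (after dropping BOS) disagrees with this at every position except the first.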

david-lindner (Collaborator) commented

Thanks! Fixed by 001bdb3, but feel free to reopen if you find other cases the validator doesn't catch.

langosco commented Jan 16, 2024

Came across another case that the validator doesn't catch:

from tracr.rasp import rasp
from tracr.compiler import compiling, validating

sel = rasp.Select(rasp.indices, rasp.tokens, rasp.Comparison.EQ)
sop = rasp.Aggregate(sel, rasp.indices)
program = rasp.Aggregate(sel, sop)


model = compiling.compile_rasp_to_model(program, vocab={1,2,3,4}, max_seq_len=5, compiler_bos="BOS")
compiled_output = model.apply(["BOS", 1, 2, 3, 4]).decoded
rasp_output = program([1, 2, 3, 4])


# The output of the compiled model does not match the output of the RASP program:
print(rasp_output)  # [2.0, 3.0, None, None]
print(compiled_output) # ['BOS', 2, 3, 0, 1]

# The validator doesn't catch the error:
print(validating.validate(program, [1, 2, 3, 4])) # []
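
To see where the `None` entries in the RASP output come from, the two aggregations can be traced by hand in plain Python. This is a rough sketch, not tracr's evaluator: `select_eq` and `aggregate` are hypothetical stand-ins, and treating `None` values as missing is an assumption of the sketch.

```python
def select_eq(keys, queries):
    """sel[q][k] is True where keys[k] == queries[q] (stand-in for Comparison.EQ)."""
    return [[k == q for k in keys] for q in queries]


def aggregate(selector, values):
    """Mean of the selected values per query position; None if nothing usable is selected."""
    out = []
    for row in selector:
        # Assumption: a selected None contributes nothing.
        picked = [v for hit, v in zip(row, values) if hit and v is not None]
        out.append(sum(picked) / len(picked) if picked else None)
    return out


tokens = [1, 2, 3, 4]
indices = list(range(len(tokens)))

sel = select_eq(keys=indices, queries=tokens)  # position i attends to index tokens[i]
sop = aggregate(sel, indices)                  # [1.0, 2.0, 3.0, None]
traced = aggregate(sel, sop)                   # [2.0, 3.0, None, None]
print(sop, traced)
```

Position 3 holds token 4, which matches no index, so nothing is selected there; position 2 selects position 3, whose intermediate value is already `None`. The compiled model returns `0` and `1` at those two positions instead.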

langosco commented

It also seems worth linking the two other cases documented in pull requests #13 and #14.

david-lindner reopened this Jan 16, 2024

david-lindner (Collaborator) commented

For all of these cases, can you try increasing the mlp_exactness parameter, i.e. add mlp_exactness=100 to the call to compile? I suspect that, at least for #14, the issue is an approximation error in the MLP layer of the selector width.

langosco commented

You're right, looks like that fixes #14! mlp_exactness=100 is the default already, but #14 compiles fine when using mlp_exactness=120.

Unfortunately, it doesn't seem to fix the other cases.
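
One way to check several mlp_exactness values systematically is a small sweep that compares each compiled output against the RASP reference. The sketch below is library-agnostic: `first_matching_exactness` and the `compile_and_run` callable are hypothetical names introduced here, and the caller is assumed to supply the actual tracr calls.

```python
def first_matching_exactness(compile_and_run, reference, candidates):
    """Return the first candidate whose compiled output equals `reference`, else None.

    `compile_and_run(exactness)` is a caller-supplied (hypothetical) callable that
    compiles the program with mlp_exactness=exactness and returns the decoded
    output with the BOS entry already removed.
    """
    for exactness in candidates:
        # Python treats 2 == 2.0 as equal, so int and float outputs compare fine.
        if list(compile_and_run(exactness)) == list(reference):
            return exactness
    return None


# Possible wiring for the program above (assumes tracr is installed):
#
# def compile_and_run(exactness):
#     model = compiling.compile_rasp_to_model(
#         program, vocab={1, 2, 3, 4}, max_seq_len=5,
#         compiler_bos="BOS", mlp_exactness=exactness)
#     return model.apply(["BOS", 1, 2, 3, 4]).decoded[1:]
#
# first_matching_exactness(compile_and_run, program([1, 2, 3, 4]), [100, 120, 150])
```

A result of `None` for every candidate would suggest the mismatch is not an MLP approximation error, which is what the remaining cases here seem to indicate.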
