Gauging community interest for a new validation library for jax. #19863
Unanswered
smithblack-0
asked this question in
General
Replies: 0 comments
Introduction
Hello All.
I have been developing a verification architecture and corresponding library that is designed to make validation much easier in the long run in jax.
I got started on this whole journey partly because I was frustrated with how primitive the jax validation mechanisms are for value-based cases (jax.jit, WHY do you not automatically stage checkify!!!?) and partly as an exercise in how to build a really solid validation system built specifically for tensors that is production quality.
While I admit that at first I thought I would likely keep it to myself, the end product is proving sophisticated enough that I would like to see whether the broader community might be interested in incorporating it in some manner into jax itself, or into some other project, if it looks useful enough. I could also publish it as a standalone project.
What is the architecture?
Currently, I have an elegant way to reuse validation operators, plus support for most of the
error-handling situations I can think of in professional code.
Basically, you define a validation operator using the chain-of-responsibility pattern, so that each node in the
chain can be responsible for validating one thing, or maybe for handling an exception from further down the chain. Then, you link the nodes together. This is done using magic methods and a straightforward architecture.
Once the chain is invoked, errors are not actually raised, but passed back up the chain of responsibility. This means you
can do some clever things, like changing the behavior to log or throw, just by swapping out the head of the chain. You can
even define a custom terminal node to integrate with your existing bespoke framework.
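To make the idea concrete, here is a minimal pure-Python sketch of such a chain. It is not the library's actual API: the names `Validator`, `check`, `handle`, `IsNumber`, `IsPositive`, `Raise`, and `Log`, and the choice of `>>` as the linking operator, are all assumptions made for illustration.

```python
import copy


class Validator:
    """One link in the chain (hypothetical base class, for illustration only)."""

    _next = None  # the following link, or None at the tail

    def check(self, operand):
        """Return an exception describing the problem, or None if all is well."""
        return None

    def handle(self, error):
        """Receive an error from further down; return None to swallow it."""
        return error

    def __rshift__(self, other):
        # `a >> b` builds a NEW chain; neither a nor b is mutated.
        clone = copy.copy(self)
        clone._next = other if self._next is None else self._next >> other
        return clone

    def __call__(self, operand):
        error = self.check(operand)
        if error is not None or self._next is None:
            return error
        # Errors from further down are passed back up through handle().
        error = self._next(operand)
        return None if error is None else self.handle(error)


class IsNumber(Validator):
    def check(self, operand):
        if not isinstance(operand, (int, float)):
            return TypeError(f"expected a number, got {type(operand).__name__}")


class IsPositive(Validator):
    def check(self, operand):
        if operand <= 0:
            return ValueError(f"expected a positive value, got {operand}")


class Raise(Validator):
    """Head node that turns any error coming up the chain into a raise."""

    def handle(self, error):
        raise error


class Log(Validator):
    """Head node that reports the error and lets execution continue."""

    def handle(self, error):
        print(f"validation failed: {error}")
        return None


rules = IsNumber() >> IsPositive()  # reusable, never mutated afterwards
strict = Raise() >> rules           # same rules, raising behavior
lenient = Log() >> rules            # same rules, logging behavior
```

In this sketch `rules(-1)` returns a `ValueError` instance without raising it, `strict(-1)` raises it, and `lenient(-1)` prints the message and returns `None`. Only the head node differs between the two behaviors.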
Under the hood, the library sets up a linked list that is transparently maintained without subclasses having to care about it. Everything is built to operate in a functional manner, meaning you will not break your existing linked lists by reusing them. Additionally, to handle the performance hits this will cause when I start to handle pytrees, all constructor calls are cached, and trying to construct a chain with the same arguments will reuse the existing class.
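One way constructor caching of this kind could work is a metaclass that memoizes instances by their arguments. This is a sketch under assumptions, not the library's implementation; `CachedMeta` and `InRange` are hypothetical names.

```python
class CachedMeta(type):
    """Metaclass that returns the same instance for the same constructor arguments."""

    _instances = {}

    def __call__(cls, *args, **kwargs):
        # The key must be hashable, so every argument has to be hashable too.
        key = (cls, args, tuple(sorted(kwargs.items())))
        if key not in CachedMeta._instances:
            CachedMeta._instances[key] = super().__call__(*args, **kwargs)
        return CachedMeta._instances[key]


class InRange(metaclass=CachedMeta):
    """Hypothetical validator node; treated as immutable after construction."""

    def __init__(self, low, high):
        self.low = low
        self.high = high

    def __call__(self, x):
        return self.low <= x <= self.high


a = InRange(0, 1)
b = InRange(0, 1)  # served from the cache, no new object is built
c = InRange(0, 2)  # different arguments, different object
```

A detail worth testing in any such scheme is key equality: in Python `1 == 1.0` and both hash the same, so `InRange(0, 1)` and `InRange(0, 1.0)` would share one cache entry here.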
Finally, if I am not misreading the situation, it can also assist significantly with jit compatibility. The traditional issue is that
you cannot raise errors based on the values of tensors, as it creates divergent jit paths. We can fix that. If you mark your error messages as, say, error._jit_incompatible, you can catch the error on its way up the chain and choose to, say, print it to the console using jax.debug.print. That would safely discharge the side effects and allow jax to continue. I am not entirely sure yet whether it will allow you
to completely drop checkify when operating in such a mode, but it is possible.
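For reference, the two mechanisms mentioned above look roughly like this in plain JAX, without the validation library. The function names `safe_log` and `checked_log` are made up for the example; `jax.lax.cond`, `jax.debug.print`, and `jax.experimental.checkify` are existing JAX APIs.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


@jax.jit
def safe_log(x):
    bad = jnp.any(x <= 0)
    # A Python-level `if bad: raise ValueError(...)` would fail here, because
    # `bad` is a tracer under jit. Instead, report the problem as a debug
    # side effect and let the computation continue.
    jax.lax.cond(
        bad,
        lambda: jax.debug.print("validation failed: non-positive input {x}", x=x),
        lambda: None,
    )
    return jnp.log(x)


def checked_log(x):
    # checkify turns the check into an error value returned with the result.
    checkify.check(jnp.all(x > 0), "input must be positive")
    return jnp.log(x)


checked = jax.jit(checkify.checkify(checked_log))

out = safe_log(jnp.array([1.0, 2.0]))        # no message printed
err, _ = checked(jnp.array([1.0, -1.0]))     # err carries the failed check
message = err.get()                          # a string here; None when all checks pass
```

The `jax.debug.print` route only reports the problem, while `checkify` hands back an error object that the caller can inspect or re-raise with `err.throw()` outside the jitted function.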
What is working, and what is planned?
Working
Currently, the core validator is working and mostly tested, although I have not done caching much before and need to figure out what tests to write for the hash function.
This means simple behavior, such as validating a single
operand against many different cases, is supported. Also supported
is reusing previously defined Validator chains in any other context.
Everything is functional, so do not worry about side effects.
Planned: Operator library
The basic operators library is not implemented yet. I could use some feedback
here.
Pytrees
Pytree support is a planned core part of the library.
The whole reason I am working in jax rather than torch or tensorflow, in fact, is pytrees.
So I have been somewhat frustrated that pytrees are so tricky to validate. You
basically have to hunt down and understand jax.tree_util.tree_map, and if you
want to validate that two trees are the SAME it gets a whole lot harder.
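As an illustration of what this takes today with only the public `jax.tree_util` API, here is a small sketch. The helper names `same_structure` and `shape_mismatches` are my own, not part of jax or of the library.

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}
grads = {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}


def same_structure(a, b):
    """True if both pytrees have the same containers and the same leaf positions."""
    return jax.tree_util.tree_structure(a) == jax.tree_util.tree_structure(b)


def shape_mismatches(a, b):
    """List (path, shape_a, shape_b) for leaves whose shapes differ.

    Assumes same_structure(a, b) already holds.
    """
    leaves_a, _ = jax.tree_util.tree_flatten_with_path(a)
    leaves_b = jax.tree_util.tree_leaves(b)
    return [
        (jax.tree_util.keystr(path), x.shape, y.shape)
        for (path, x), y in zip(leaves_a, leaves_b)
        if x.shape != y.shape
    ]


ok = same_structure(params, grads) and not shape_mismatches(params, grads)
```

Even this small check needs three different tree utilities, and it still says nothing about dtypes or values.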
I do not like that. Instead, I plan to create a class called Schema
which can be made per batch or statically, and which itself contains
a pytree. You will be able to specify validator chains to run
for each leaf of the pytree. You will also be able to use it to,
for example, apply one validator chain across an entire
pytree in "prefix tree broadcasting."
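The planned Schema class does not exist yet, so the following is only a sketch of how prefix tree broadcasting behaves with the public `jax.tree_util.tree_map`, which accepts extra trees that have the first tree as a prefix. The names `is_finite`, `is_matrix`, `schema`, and `run` are hypothetical, and plain functions stand in for validator chains.

```python
import jax
import jax.numpy as jnp


def is_finite(x):
    return bool(jnp.all(jnp.isfinite(x)))


def is_matrix(x):
    return x.ndim == 2


batch = {
    "encoder": {"w": jnp.ones((4, 4)), "b": jnp.zeros((4,))},
    "decoder": {"w": jnp.ones((4, 2))},
}

# The schema is a PREFIX of the batch: one check per top-level key,
# broadcast to every leaf underneath that key.
schema = {"encoder": is_finite, "decoder": is_matrix}


def run(check, subtree):
    # `check` is a leaf of the schema; `subtree` is everything below it in the batch.
    return jax.tree_util.tree_map(check, subtree)


results = jax.tree_util.tree_map(run, schema, batch)

# A single check is itself a valid prefix of any tree, so this applies
# one check across the entire pytree.
all_finite = jax.tree_util.tree_map(run, is_finite, batch)
```

Here `results` has the same structure as `batch`, with one boolean per leaf. The checks run eagerly on concrete arrays in this sketch; making them work under jit is a separate problem.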
Thankfully, I recently discovered that jax._src.tree_util in the
jax source code has utilities for doing this already, and even for raising good error messages,
so I will be importing a lot of the work from there. Since they are unit tested,
I think it should be fine. It also should be fairly fast, hopefully, as I
will be leaning on the optimized jax code.
Can I have?
At the moment, the library is not published on PyPI, but the repository is located at
the following link if you want to poke around. Please note it is fairly unpolished
at the moment, but feel free to clone it if you want.
https://github.com/smithblack-0/jax_validation
An example?
Let's consider the following spec
Let's see how the current library would implement that. This code is entirely
functional at the moment, and since the operators library is not in place
yet we have to define each validation operation ourselves. That is not a bad thing,
though, as it shows the three main methods that can be overridden to influence
behavior.
The following code runs, and appears to work, with the current library.