Add get samples function to InstructionToSignals for JAX-jit usage#149
Conversation
|
Thanks for the PR @to24toro ! I think the code you've added to Dynamics looks good - however there is a core problem right now that prevents this from being useful. Your test: sets If you try to do this now with the current which is because in this case, With this error being raised, there is unfortunately no benefit to having To move forward with this PR, I think it makes sense to figure out whatever changes are necessary in terra to make the above test case work with the |
|
The special casing of amp must be removed from terra (only duration and amp are the non-float and unfortunately Python is not typed). Currently Qiskit/qiskit#9002 is trying to introduce (amp, float) representation to symbolic pulses, and this makes all parameters float-type except for duration. This approach allows us to remove builtin typecasting and we can eventually remove Feel free to on hold this until above PR is merged, or you can merge this PR without JIT test and later add the the test in a separate (follow up) PR. |
|
Okay cool thanks.
I think it makes sense to hold on this PR, as I don't think the required functionality can truly be verified until terra is at a point where the |
|
Fair enough. Let's merge Qiskit/qiskit#9002 first. |
| def jit_func(amp): | ||
| return get_samples(Constant(100, amp)) | ||
|
|
||
| jit_samples = jax.jit(jit_func)(0.1) |
There was a problem hiding this comment.
We expected this function will pass without using static_argnums after merging Qiskit/qiskit#9002 .
|
Hey @to24toro , there are two more tests I think it'd be good to add:
One last point: the |
ad0e3aa to
b82c152
Compare
| self.jit_wrap(jit_func)(0.1) | ||
| self.jit_grad_wrap(jit_func)(0.1) |
There was a problem hiding this comment.
I tried to use jit_wrap and jit_grad_wrap.
Very useful.
b82c152 to
f36cb66
Compare
| and (method == "jax_odeint" or _is_diffrax_method(method)) | ||
| and all(isinstance(x, Schedule) for x in signals_list) | ||
| # check if jit transformation is already performed. | ||
| and not (isinstance(jnp.array(0), core.Tracer)) |
|
I modified two points at 769e342 for JAX-jitting. |
|
I've removed the "on hold" label, as you pointed out that terra has been sufficiently updated. Once the errors are resolved and I re-review we can merge this! |
DanPuzzuoli
left a comment
There was a problem hiding this comment.
Aside from current test failing issue (which seems to just be a sympy issue), I have some comments about reorganizing the new tests a little.
Looks good though.
| self.assertTrue(signals[2].carrier_freq == 0.0) | ||
| self.assertTrue(signals[3].carrier_freq == 4.0) | ||
|
|
||
| def test_get_samples(self): |
There was a problem hiding this comment.
Rewrite these tests to still call the InstructionToSignals converter. get_samples and lru_cache_expr are both internal functions, so ideally we won't directly test them. There are exceptions to this in Dynamics, but in this case I don't think it's necessary. The inputs/outputs of these functions are directly fed from/to the converter, so I don't think anything is gained by directly testing these functions, as opposed to verifying the behaviour of the converter on a symbolic pulse.
Maybe change the name of this to test_SymbolicPulse.
There was a problem hiding this comment.
improved and changed the name at e2c5148
| ) | ||
|
|
||
|
|
||
| class TestJaxGetSamples(QiskitDynamicsTestCase, TestJaxBase): |
There was a problem hiding this comment.
Similar to the preceding comment, I don't think it's necessary to have a separate class for checking the behaviour of get_samples. These tests could be moved to the previous class, by rewriting them to verify that the converter returns the same thing, whether you pass in a Waveform or a SymbolicPulse.
The jit test is great and very important, but again, it can be moved to the previous class and rewritten to work with the converter.
There was a problem hiding this comment.
For this I mean to fully get rid of this test class, and to move the cases it's testing into the preceding test class. by calling InstructionToSignals directly.
There was a problem hiding this comment.
The class you said moving to is TestPulseToSignals?
If I move to the class, we fail to pass the test because TestPulseToSignals is not a subclass of TestJaxBase.
On the other hand, if TestPulseToSignals inherits TestJaxBase, all the test in TestPulseToSignals will be skipped in python test and executed in only JAX test.
There was a problem hiding this comment.
I understand now. Maybe just rename this class then to TestPulseToSignalsJAXTransformations or something along these lines.
3f214c5 to
03592c1
Compare
Co-authored-by: Daniel Puzzuoli <dan.puzzuoli@gmail.com>
03592c1 to
bf66412
Compare
|
My author and commit name were not correct. So I am sorry to have modify them and force-push to pass license/cla. |
DanPuzzuoli
left a comment
There was a problem hiding this comment.
There are some conflicts to resolve, and some final tests calling get_samples to convert to calls directly to InstructionToSignals.
|
|
||
| See the :meth:`get_signals` method documentation for a detailed description of how pulse | ||
| schedules are interpreted and translated into :class:`.DiscreteSignal` objects. | ||
| schedule. The converter applies to instances of :class:`Schedule`. Instances of |
There was a problem hiding this comment.
I think something's happened here - this is undoing some changes I had made in a previous PR (that I don't think you are intentionally trying to change).
There was a problem hiding this comment.
I am so sorry to confuse you. As you say, it is not my intention. I restored to that of main.
| dt: Length of the samples. This is required by the converter as pulse schedule are | ||
| specified in units of dt and typically do not carry the value of dt with them. | ||
| carriers: A dict of analog carrier frequencies. The keys are the names of the channels | ||
| dt: Length of the samples. This is required by the converter as pulse |
There was a problem hiding this comment.
These two code blocks are similarly changing some documentation I had changed previously. Maybe some issue with a previous attempt at merging main into this branch?
| :class:`~qiskit.pulse.ScheduleBlock` must first be converted to | ||
| :class:`~qiskit.pulse.Schedule` using the | ||
| :func:`~qiskit.pulse.transforms.block_to_schedule` function in Qiskit Pulse. | ||
| :class:`ScheduleBlock` must first be converted to :class:`Schedule` using the |
There was a problem hiding this comment.
Similarly here, this is undoing some changes I had made.
| ) | ||
|
|
||
|
|
||
| class TestJaxGetSamples(QiskitDynamicsTestCase, TestJaxBase): |
There was a problem hiding this comment.
For this I mean to fully get rid of this test class, and to move the cases it's testing into the preceding test class. by calling InstructionToSignals directly.
| envelope=envelope_expr, | ||
| valid_amp_conditions=valid_amp_conditions_expr, | ||
| ) | ||
| return get_samples(instance) |
There was a problem hiding this comment.
E.g. this test should use InstructionToSignals instead of get_samples directly.
| jit_samples = jax.jit(jit_func_get_samples)(0.1) | ||
| self.assertAllClose(jit_samples, self.constant_get_waveform_samples, atol=1e-7, rtol=1e-7) | ||
|
|
||
| def test_pulse_types_combination_with_jax(self): |
There was a problem hiding this comment.
This is a great test, that could be moved to the preceding class.
DanPuzzuoli
left a comment
There was a problem hiding this comment.
Looks good! One final documentation comment.
|
|
||
| The converter can be initialized with the optional arguments ``carriers`` and ``channels``. When | ||
| ``channels`` is given, only the signals specified by name in ``channels`` are returned. The | ||
| ``carriers`` dictionary specifies the analog carrier frequency of each channel. Here, the keys | ||
| are the channel name, e.g. ``d12`` for drive channel number ``12``, and the values are the | ||
| corresponding frequency. If a channel is not present in ``carriers`` it is assumed that the | ||
| analog carrier frequency is zero. | ||
|
|
There was a problem hiding this comment.
My last question is: did you intentionally remove these spaces? Deleting these causes this whole block of text to be a single paragraph in the docs. I prefer these as separate paragraphs, as each is highlighting something different. Unless there is a specific reason for this, can you put the spaces back in.
There was a problem hiding this comment.
Sorry. I didn't take into account the paragraph change. I returned back at c3338c5.
Summary
This PR adds
get_samplestoInstructionToSignalsfor JAX-jitting when using qiskit-pulse and removes the usage ofget_waveformmethod ofSymbolicPulse.Details and comments
get_samplesfunction gets the envelope expression formSymbolicPulseand calls sympy.lambdify with numerical backend specified by Array class. The lambdified function is lru cached for performance.