multi conditions/switch with predicate functions #14264
Answered by jakevdp. jecampagne asked this question in Q&A.
Hello,

```python
def f(x):
    if x < 0:
        x = -x
    if x == 0:
        return 1
    if x <= 4:  # x in (0, 4]
        value = ...
    elif x <= 8.0:  # x in (4, 8]
        value = ...
    else:
        value = ...
    return value
```

Thanks
Answered by jakevdp on Feb 2, 2023
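One way to do this is using `jnp.select`, which runs through a list of conditions and returns the value associated with the first condition that is true (or `default` if none is). For example:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    x = jnp.abs(x)
    return jnp.select(
        [x == 0, x <= 4, x <= 8],
        [x, x - 1, x - 2],
        default=x)

print(f(0))
# 0
print(f(-2))
# 1
print(f(7))
# 5
print(f(10))
# 10
```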
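A possible variation, not part of the answer above: `jnp.select` computes every entry of the choice list and then picks one, which is cheap here but can be wasteful when each branch is an expensive computation. `jax.lax.switch(index, branches, *operands)` instead takes a list of branch functions and an integer index. The sketch below assumes the same conditions and placeholder values as the `jnp.select` example; the name `f_switch` and the branch lambdas are hypothetical stand-ins for the real `value = ...` computations.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def f_switch(x):
    x = jnp.abs(x)
    # Same priority order as the jnp.select version: the first true
    # condition wins, and index 3 is the fallback ("else") branch.
    index = jnp.select([x == 0, x <= 4, x <= 8], [0, 1, 2], default=3)
    # Hypothetical placeholder branches mirroring the jnp.select values.
    # Every branch must return the same shape and dtype.
    branches = [
        lambda v: v,      # x == 0
        lambda v: v - 1,  # x in (0, 4]
        lambda v: v - 2,  # x in (4, 8]
        lambda v: v,      # x > 8
    ]
    # lax.switch calls only branches[index] on the operand x.
    return lax.switch(index, branches, x)


print(f_switch(0.0))   # 0.0
print(f_switch(-2.0))  # 1.0
print(f_switch(7.0))   # 5.0
print(f_switch(10.0))  # 10.0
```

Under `jax.jit` with a scalar index, only the selected branch is executed. If the function is batched with `jax.vmap` so that the index itself becomes a batch, JAX converts the switch into a select and all branches are evaluated again, so the two approaches then behave much the same.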
Answer selected by jecampagne.