diff --git a/jaxtyping/_array_types.py b/jaxtyping/_array_types.py
index dcd0115..8cb3dc0 100644
--- a/jaxtyping/_array_types.py
+++ b/jaxtyping/_array_types.py
@@ -25,7 +25,16 @@
 import types
 import typing
 from dataclasses import dataclass
-from typing import Any, Literal, NoReturn, Optional, TypeVar, Union
+from typing import (
+    Any,
+    get_args,
+    get_origin,
+    Literal,
+    NoReturn,
+    Optional,
+    TypeVar,
+    Union,
+)
 
 
 # Bit of a hack, but jaxtyping provides nicer error messages than typeguard. This means
@@ -358,7 +367,7 @@ class for `Float32[Array, "foo"]`.
 
 _not_made = object()
 
-_union_types = [typing.Union]
+_union_types = [Union]
 if sys.version_info >= (3, 10):
     _union_types.append(types.UnionType)
 
@@ -517,6 +526,9 @@ def _make_array_cached(array_type, dim_str, dtypes, name):
     # Allow Python built-in numeric types.
     # TODO: do something more generic than this? Should we _make all types
     # that have `shape` and `dtype` attributes or something?
+    array_origin = get_origin(array_type)
+    if array_origin is not None:
+        array_type = array_origin
     if array_type is bool:
         if _check_scalar("bool", dtypes, dims):
             return array_type
@@ -547,7 +559,7 @@ def _make_array_cached(array_type, dim_str, dtypes, name):
             return array_type
         else:
             return _not_made
-    if issubclass(array_type, AbstractArray):
+    if array_type is not Any and issubclass(array_type, AbstractArray):
         if dtypes is _any_dtype:
             dtypes = array_type.dtypes
         elif array_type.dtypes is not _any_dtype:
@@ -588,11 +600,15 @@ def _make_array(*args, **kwargs):
 
     if type(out) is tuple:
         array_type, name, dtypes, dims, index_variadic, dim_str = out
-        metaclass = _make_metaclass(type(array_type))
+        metaclass = (
+            _make_metaclass(type)
+            if array_type is Any
+            else _make_metaclass(type(array_type))
+        )
 
         out = metaclass(
             name,
-            (array_type, AbstractArray),
+            (AbstractArray,) if array_type is Any else (array_type, AbstractArray),
             dict(
                 array_type=array_type,
                 dtypes=dtypes,
@@ -629,14 +645,18 @@ def __getitem__(cls, item: tuple[Any, str]):
         if isinstance(array_type, TypeVar):
             bound = array_type.__bound__
             if bound is None:
-                array_type = Any
+                constraints = array_type.__constraints__
+                if constraints == ():
+                    array_type = Any
+                else:
+                    array_type = Union[constraints]
             else:
                 array_type = bound
         del item
-        if typing.get_origin(array_type) in _union_types:
+        if get_origin(array_type) in _union_types:
             out = [
                 _make_array(x, dim_str, cls.dtypes, cls.__name__)
-                for x in typing.get_args(array_type)
+                for x in get_args(array_type)
             ]
             out = tuple(x for x in out if x is not _not_made)
             if len(out) == 0:
diff --git a/pyproject.toml b/pyproject.toml
index 746b48f..db44eac 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "jaxtyping"
-version = "0.2.32"
+version = "0.2.33"
 description = "Type annotations and runtime checking for shape and dtype of JAX arrays, and PyTrees."
 readme = "README.md"
 requires-python ="~=3.9"