Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/lighteval/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,15 @@
from typing import Iterator, Tuple

import torch
from packaging import version
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler, T_co


if version.parse(torch.__version__) >= version.parse("2.5.0"):
from torch.utils.data.distributed import DistributedSampler, _T_co
else:
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.distributed import T_co as _T_co

from lighteval.tasks.requests import (
GreedyUntilRequest,
Expand Down Expand Up @@ -318,7 +325,7 @@ class GenDistributedSampler(DistributedSampler):
as our samples are sorted by length.
"""

def __iter__(self) -> Iterator[T_co]:
def __iter__(self) -> Iterator[_T_co]:
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
Expand Down
Loading