Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ dependencies = [
"tqdm>=4.62",
"matplotlib>=3.3",
"scikit-learn>=1.0",
"transformers>=4.25",
"transformers>=4.46",
"huggingface-hub>=0.16",
"datasets>=2.0",
"peft>=0.4",
Expand Down
136 changes: 117 additions & 19 deletions src/opentslm/model/llm/OpenTSLMFlamingo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
)
from open_flamingo.src.flamingo_lm import FlamingoLMMixin
from open_flamingo.src.utils import extend_instance
import random
import torch
import torch._dynamo
from typing import List, Dict, Tuple
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import Dict, Iterator, List, Tuple
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

from opentslm.model_config import ENCODER_OUTPUT_DIM
from opentslm.model.llm.TimeSeriesLLM import TimeSeriesLLM
Expand Down Expand Up @@ -108,12 +109,16 @@ def _infer_decoder_layers_attr_name(model):
return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]

raise ValueError(
f"We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
"We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
)

decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
lang_encoder.resize_token_embeddings(len(text_tokenizer))
# Avoid Transformers' mean-based resize path, which can trip on meta tensors
# during low-memory model initialization in production worker environments.
lang_encoder.resize_token_embeddings(
len(text_tokenizer), mean_resizing=False
)
Comment thread
masquare marked this conversation as resolved.

# Fix compatibility for Gemma3Config which has hidden_size in text_config
if hasattr(lang_encoder.config, "text_config") and hasattr(
Expand Down Expand Up @@ -152,6 +157,39 @@ def _infer_decoder_layers_attr_name(model):
self.llm = model
self.text_tokenizer = text_tokenizer

def _build_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.lang_encoder.get_input_embeddings()(input_ids)

def _forward_with_embeddings(
self,
images: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels: torch.Tensor | None = None,
):
inputs_embeds = self._build_input_embeddings(input_ids)
try:
self.model._encode_vision_x(vision_x=images)
self._condition_media_locations(input_ids)
return super(FlamingoLMMixin, self.model.lang_encoder).forward(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
labels=labels,
)
finally:
self.model.lang_encoder.clear_conditioned_layers()

def _condition_media_locations(self, input_ids: torch.Tensor):
media_locations = input_ids == self.model.media_token_id
attend_previous = (
(random.random() < 0.5)
if self.model.lang_encoder.use_media_placement_augmentation
else False
)
for layer in self.model.lang_encoder._get_decoder_layers():
layer.condition_media_locations(media_locations)
layer.condition_attend_previous(attend_previous)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

def pad_and_apply_batch(
self, batch: List[Dict[str, any]], include_labels: bool
) -> Tuple[torch.Tensor, torch.Tensor]:
Expand Down Expand Up @@ -255,16 +293,20 @@ def generate(
input_ids, images, attention_mask, _ = self.pad_and_apply_batch(
batch, include_labels=True
)

gen_ids = self.llm.generate(
vision_x=images,
lang_x=input_ids,
attention_mask=attention_mask,
max_new_tokens=max_new_tokens,
eos_token_id=self.text_tokenizer.eos_token_id,
pad_token_id=self.text_tokenizer.pad_token_id,
**generate_kwargs,
)
inputs_embeds = self._build_input_embeddings(input_ids)
try:
self.model._encode_vision_x(vision_x=images)
self._condition_media_locations(input_ids)
gen_ids = self.model.lang_encoder.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
max_new_tokens=max_new_tokens,
eos_token_id=self.text_tokenizer.eos_token_id,
pad_token_id=self.text_tokenizer.pad_token_id,
**generate_kwargs,
)
finally:
self.model.lang_encoder.clear_conditioned_layers()

# Remove input ids from generation
answer_only_ids = gen_ids[:, input_ids.shape[1] :]
Comment thread
masquare marked this conversation as resolved.
Expand All @@ -276,6 +318,46 @@ def generate(
# Restore original compilation setting
torch._dynamo.config.disable = original_disable

def stream_generate(
self, batch: List[Dict[str, any]], max_new_tokens: int = 50, **generate_kwargs
) -> Iterator[str]:
self._validate_streaming_batch(batch)

original_disable = torch._dynamo.config.disable
torch._dynamo.config.disable = True
try:
with torch.inference_mode():
input_ids, images, attention_mask, _ = self.pad_and_apply_batch(
batch, include_labels=True
)
inputs_embeds = self._build_input_embeddings(input_ids)
streamer = TextIteratorStreamer(
self.text_tokenizer,
skip_prompt=True,
timeout=generate_kwargs.pop("stream_timeout", None),
)

def run_generation() -> None:
with torch.inference_mode():
try:
self.model._encode_vision_x(vision_x=images)
self._condition_media_locations(input_ids)
self.model.lang_encoder.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
max_new_tokens=max_new_tokens,
eos_token_id=self.text_tokenizer.eos_token_id,
pad_token_id=self.text_tokenizer.pad_token_id,
streamer=streamer,
**generate_kwargs,
)
finally:
self.model.lang_encoder.clear_conditioned_layers()

yield from self._iterate_streamer(streamer, run_generation)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
finally:
torch._dynamo.config.disable = original_disable
Comment thread
masquare marked this conversation as resolved.

def compute_loss(self, batch: List[Dict[str, any]]) -> torch.Tensor:
"""
batch: same format as generate()
Expand All @@ -285,9 +367,9 @@ def compute_loss(self, batch: List[Dict[str, any]]) -> torch.Tensor:
batch, include_labels=False
)

output = self.model(
vision_x=images,
lang_x=input_ids,
output = self._forward_with_embeddings(
images=images,
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
)
Expand Down Expand Up @@ -333,13 +415,13 @@ def load_from_file(self, path: str = "best_model.pt"):
# Load state dict with strict=False to handle missing/unexpected keys
missing_keys, unexpected_keys = self.load_state_dict(model_state, strict=False)
if missing_keys:
print(f"⚠️ Warning: Missing keys when loading checkpoint:")
print("⚠️ Warning: Missing keys when loading checkpoint:")
for key in missing_keys[:10]:
print(f" - {key}")
if len(missing_keys) > 10:
print(f" ... and {len(missing_keys) - 10} more keys")
if unexpected_keys:
print(f"⚠️ Warning: Unexpected keys when loading checkpoint:")
print("⚠️ Warning: Unexpected keys when loading checkpoint:")
for key in unexpected_keys[:10]:
print(f" - {key}")
if len(unexpected_keys) > 10:
Expand Down Expand Up @@ -368,3 +450,19 @@ def eval_prompt(
finally:
# Restore original compilation setting
torch._dynamo.config.disable = original_disable

def stream_prompt(
self,
prompt: FullPrompt,
max_new_tokens: int = 1000,
normalize: bool = False,
**generate_kwargs,
) -> Iterator[str]:
batch = [prompt.to_dict()]
self.eval()
batch = extend_time_series_to_match_patch_size_and_aggregate(
batch, normalize=normalize
)
yield from self.stream_generate(
batch, max_new_tokens=max_new_tokens, **generate_kwargs
)
59 changes: 47 additions & 12 deletions src/opentslm/model/llm/OpenTSLMSP.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
# SPDX-License-Identifier: MIT

import torch
import torch.nn as nn
from typing import List, Dict, Tuple, Optional
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import Any, Dict, Iterator, List, Optional, Tuple
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
Comment thread
coderabbitai[bot] marked this conversation as resolved.
from torch.nn.utils.rnn import pad_sequence

try:
Expand Down Expand Up @@ -47,7 +46,11 @@ def __init__(
device_map={"": device},
attn_implementation="eager",
)
self.llm.resize_token_embeddings(len(self.tokenizer))
# Avoid Transformers' mean-based resize path, which can trip on meta tensors
# during low-memory model initialization in production worker environments.
self.llm.resize_token_embeddings(
len(self.tokenizer), mean_resizing=False
)

# 3) encoder + projector (now internal)
self.encoder = TransformerCNNEncoder().to(device)
Expand Down Expand Up @@ -133,7 +136,7 @@ def enable_lora(
p.numel() for p in self.llm.parameters() if p.requires_grad
)
total_params = sum(p.numel() for p in self.llm.parameters())
print(f"✅ LoRA enabled:")
print("✅ LoRA enabled:")
print(f" LoRA parameters: {lora_params:,}")
print(f" Total trainable parameters: {trainable_params:,}")
print(f" Total parameters: {total_params:,}")
Expand Down Expand Up @@ -327,6 +330,29 @@ def generate(
)
return self.tokenizer.batch_decode(gen_ids, skip_special_tokens=True)

def stream_generate(
self, batch: List[Dict[str, Any]], max_new_tokens: int = 50, **generate_kwargs
) -> Iterator[str]:
self._validate_streaming_batch(batch)

inputs_embeds, attention_mask = self.pad_and_apply_batch(batch)
streamer = TextIteratorStreamer(
self.tokenizer,
skip_prompt=True,
timeout=generate_kwargs.pop("stream_timeout", None),
)

def run_generation() -> None:
self.llm.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
max_new_tokens=max_new_tokens,
streamer=streamer,
**generate_kwargs,
)

yield from self._iterate_streamer(streamer, run_generation)

def compute_loss(self, batch: List[Dict[str, any]]) -> torch.Tensor:
"""
batch: same format as generate()
Expand Down Expand Up @@ -415,13 +441,6 @@ def load_lora_state_from_checkpoint(
loaded_count = 0
missing_keys = []

# Track which LoRA parameters we expect to find
expected_lora_params = {
name
for name, param in self.llm.named_parameters()
if param.requires_grad and "lora_" in name
}

for name, param in self.llm.named_parameters():
if name in lora_state and param.requires_grad and "lora_" in name:
param.data.copy_(lora_state[name])
Expand Down Expand Up @@ -504,3 +523,19 @@ def eval_prompt(
)
output = self.generate(batch, max_new_tokens=max_new_tokens)
return output[0]

def stream_prompt(
self,
prompt: FullPrompt,
max_new_tokens: int = 30000,
normalize: bool = False,
**generate_kwargs,
) -> Iterator[str]:
batch = [prompt.to_dict()]
self.eval()
batch = extend_time_series_to_match_patch_size_and_aggregate(
batch, normalize=normalize
)
yield from self.stream_generate(
batch, max_new_tokens=max_new_tokens, **generate_kwargs
)
53 changes: 51 additions & 2 deletions src/opentslm/model/llm/TimeSeriesLLM.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Dict, Any
from threading import Thread
from typing import Any, Callable, Dict, Iterator, List

# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
# SPDX-FileCopyrightText: 2025 This source file is part of the OpenTSLM open-source project.
Expand All @@ -25,6 +26,13 @@ def generate(

raise NotImplementedError("Generate method should be implemented by the subclass")

def stream_generate(
self, batch: List[Dict[str, Any]], max_new_tokens: int = 50, **generate_kwargs
) -> Iterator[str]:
raise NotImplementedError(
"stream_generate method should be implemented by the subclass"
)

def compute_loss(self, batch: List[Dict[str, Any]]) -> torch.Tensor:
"""
batch: same format as generate()
Expand All @@ -36,4 +44,45 @@ def get_eos_token(self) -> str:
raise NotImplementedError("Get eos token method should be implemented by the subclass")

def eval_prompt(self, prompt: FullPrompt) -> str:
raise NotImplementedError("Eval prompt method should be implemented by the subclass")
raise NotImplementedError("Eval prompt method should be implemented by the subclass")

def stream_prompt(
self, prompt: FullPrompt, max_new_tokens: int = 1000, normalize: bool = False, **generate_kwargs
) -> Iterator[str]:
raise NotImplementedError(
"stream_prompt method should be implemented by the subclass"
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

@staticmethod
def _validate_streaming_batch(batch: List[Dict[str, Any]]) -> None:
if len(batch) != 1:
raise ValueError(
"Streaming generation currently supports exactly one sample per batch."
)

@staticmethod
def _iterate_streamer(streamer: Any, generate_fn: Callable[[], None]) -> Iterator[str]:
error: BaseException | None = None

def runner() -> None:
nonlocal error
try:
generate_fn()
except BaseException as exc: # pragma: no cover - re-raised in caller
error = exc
end = getattr(streamer, "end", None)
if callable(end):
end()

thread = Thread(target=runner, daemon=True)
thread.start()

try:
for text in streamer:
if text:
yield text
finally:
thread.join()

if error is not None:
raise error
Comment thread
masquare marked this conversation as resolved.
Loading
Loading