diff --git a/.gitignore b/.gitignore index d3d9957f..14af3091 100644 --- a/.gitignore +++ b/.gitignore @@ -19,7 +19,7 @@ evaluation/scripts/personamem # benchmarks benchmarks/ - + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -47,6 +47,7 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST +.run # PyInstaller # Usually these files are written by a python script from a template diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 815846ed..00000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "python.testing.pytestArgs": [ - "tests", - "-vv" - ], - "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true, - "python.analysis.typeCheckingMode": "off" -} diff --git a/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler_and_openai.yaml b/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler_and_openai.yaml new file mode 100644 index 00000000..369ad396 --- /dev/null +++ b/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler_and_openai.yaml @@ -0,0 +1,51 @@ +user_id: "root" +chat_model: + backend: "huggingface" + config: + model_name_or_path: "Qwen/Qwen3-1.7B" + temperature: 0.1 + remove_think_prefix: true + max_tokens: 4096 +mem_reader: + backend: "simple_struct" + config: + llm: + backend: "openai" + config: + model_name_or_path: "gpt-4o-mini" + temperature: 0.8 + max_tokens: 4096 + top_p: 0.9 + top_k: 50 + remove_think_prefix: true + api_key: "sk-xxxxxx" + api_base: "https://api.openai.com/v1" + embedder: + backend: "ollama" + config: + model_name_or_path: "nomic-embed-text:latest" + chunker: + backend: "sentence" + config: + tokenizer_or_token_counter: "gpt2" + chunk_size: 512 + chunk_overlap: 128 + min_sentences_per_chunk: 1 +mem_scheduler: + backend: "optimized_scheduler" + config: + top_k: 10 + act_mem_update_interval: 30 + context_window_size: 10 + thread_pool_max_workers: 10 + consume_interval_seconds: 1 + 
working_mem_monitor_capacity: 20 + activation_mem_monitor_capacity: 5 + enable_parallel_dispatch: true + enable_activation_memory: true +max_turns_window: 20 +top_k: 5 +enable_textual_memory: true +enable_activation_memory: true +enable_parametric_memory: false +enable_mem_scheduler: true diff --git a/examples/mem_scheduler/debug_text_mem_replace.py b/examples/mem_scheduler/debug_text_mem_replace.py new file mode 100644 index 00000000..df80f7d0 --- /dev/null +++ b/examples/mem_scheduler/debug_text_mem_replace.py @@ -0,0 +1,109 @@ +import json +import shutil +import sys + +from pathlib import Path + +from memos_w_scheduler_for_test import init_task + +from memos.configs.mem_cube import GeneralMemCubeConfig +from memos.configs.mem_os import MOSConfig +from memos.configs.mem_scheduler import AuthConfig +from memos.log import get_logger +from memos.mem_cube.general import GeneralMemCube +from memos.mem_scheduler.analyzer.mos_for_test_scheduler import MOSForTestScheduler + + +FILE_PATH = Path(__file__).absolute() +BASE_DIR = FILE_PATH.parent.parent.parent +sys.path.insert(0, str(BASE_DIR)) + +# Enable execution from any working directory + +logger = get_logger(__name__) + +if __name__ == "__main__": + # set up data + conversations, questions = init_task() + + # set configs + mos_config = MOSConfig.from_yaml_file( + f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler_and_openai.yaml" + ) + + mem_cube_config = GeneralMemCubeConfig.from_yaml_file( + f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config.yaml" + ) + + # default local graphdb uri + if AuthConfig.default_config_exists(): + auth_config = AuthConfig.from_local_config() + + mos_config.mem_reader.config.llm.config.api_key = auth_config.openai.api_key + mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url + + mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri + mem_cube_config.text_mem.config.graph_db.config.user = 
auth_config.graph_db.user + mem_cube_config.text_mem.config.graph_db.config.password = auth_config.graph_db.password + mem_cube_config.text_mem.config.graph_db.config.db_name = auth_config.graph_db.db_name + mem_cube_config.text_mem.config.graph_db.config.auto_create = ( + auth_config.graph_db.auto_create + ) + + # Initialization + mos = MOSForTestScheduler(mos_config) + + user_id = "user_1" + mos.create_user(user_id) + + mem_cube_id = "mem_cube_5" + mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}" + + if Path(mem_cube_name_or_path).exists(): + shutil.rmtree(mem_cube_name_or_path) + print(f"{mem_cube_name_or_path} is not empty, and has been removed.") + + mem_cube = GeneralMemCube(mem_cube_config) + mem_cube.dump(mem_cube_name_or_path) + mos.register_mem_cube( + mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id + ) + + mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id) + + # Add interfering conversations + file_path = Path(f"{BASE_DIR}/examples/data/mem_scheduler/scene_data.json") + scene_data = json.load(file_path.open("r", encoding="utf-8")) + mos.add(scene_data[0], user_id=user_id, mem_cube_id=mem_cube_id) + mos.add(scene_data[1], user_id=user_id, mem_cube_id=mem_cube_id) + + # Test the replace_working_memory functionality + print("\n--- Testing replace_working_memory ---") + + # Get current working memories + text_mem_base = mem_cube.text_mem + if text_mem_base is not None: + working_memories_before = text_mem_base.get_working_memory() + print(f"Working memories before replacement: {len(working_memories_before)}") + + # Create filtered memories (simulate what the scheduler would do) + # Keep only memories related to Max + filtered_memories = [working_memories_before[1], working_memories_before[4]] + + text_mem_base.replace_working_memory(memories=filtered_memories) + + # Check working memory after replacement + working_memories_after = text_mem_base.get_working_memory() + 
print(f"Working memories after replacement: {len(working_memories_after)}") + + if len(working_memories_after) == len(filtered_memories): + print("โœ… SUCCESS: Working memory count matches filtered memories") + else: + print( + f"โŒ FAILED: Expected {len(filtered_memories)}, got {len(working_memories_after)}" + ) + + else: + print("โŒ text_mem is None - not properly initialized") + + mos.mem_scheduler.stop() diff --git a/examples/mem_scheduler/memos_w_optimized_scheduler.py b/examples/mem_scheduler/memos_w_optimized_scheduler.py new file mode 100644 index 00000000..fbd14536 --- /dev/null +++ b/examples/mem_scheduler/memos_w_optimized_scheduler.py @@ -0,0 +1,85 @@ +import shutil +import sys + +from pathlib import Path + +from memos_w_scheduler import init_task, show_web_logs + +from memos.configs.mem_cube import GeneralMemCubeConfig +from memos.configs.mem_os import MOSConfig +from memos.configs.mem_scheduler import AuthConfig +from memos.log import get_logger +from memos.mem_cube.general import GeneralMemCube +from memos.mem_os.main import MOS + + +FILE_PATH = Path(__file__).absolute() +BASE_DIR = FILE_PATH.parent.parent.parent +sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory + +logger = get_logger(__name__) + + +def run_with_scheduler_init(): + print("==== run_with_automatic_scheduler_init ====") + conversations, questions = init_task() + + # set configs + mos_config = MOSConfig.from_yaml_file( + f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler_and_openai.yaml" + ) + + mem_cube_config = GeneralMemCubeConfig.from_yaml_file( + f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config.yaml" + ) + + # default local graphdb uri + if AuthConfig.default_config_exists(): + auth_config = AuthConfig.from_local_config() + + mos_config.mem_reader.config.llm.config.api_key = auth_config.openai.api_key + mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url + + 
mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri + mem_cube_config.text_mem.config.graph_db.config.user = auth_config.graph_db.user + mem_cube_config.text_mem.config.graph_db.config.password = auth_config.graph_db.password + mem_cube_config.text_mem.config.graph_db.config.db_name = auth_config.graph_db.db_name + mem_cube_config.text_mem.config.graph_db.config.auto_create = ( + auth_config.graph_db.auto_create + ) + + # Initialization + mos = MOS(mos_config) + + user_id = "user_1" + mos.create_user(user_id) + + mem_cube_id = "mem_cube_5" + mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}" + + if Path(mem_cube_name_or_path).exists(): + shutil.rmtree(mem_cube_name_or_path) + print(f"{mem_cube_name_or_path} is not empty, and has been removed.") + + mem_cube = GeneralMemCube(mem_cube_config) + mem_cube.dump(mem_cube_name_or_path) + mos.register_mem_cube( + mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id + ) + + mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id) + + for item in questions: + print("===== Chat Start =====") + query = item["question"] + print(f"Query:\n {query}\n") + response = mos.chat(query=query, user_id=user_id) + print(f"Answer:\n {response}\n") + + show_web_logs(mem_scheduler=mos.mem_scheduler) + + mos.mem_scheduler.stop() + + +if __name__ == "__main__": + run_with_scheduler_init() diff --git a/examples/mem_scheduler/memos_w_optimized_scheduler_for_test.py b/examples/mem_scheduler/memos_w_optimized_scheduler_for_test.py new file mode 100644 index 00000000..9b39bf77 --- /dev/null +++ b/examples/mem_scheduler/memos_w_optimized_scheduler_for_test.py @@ -0,0 +1,87 @@ +import json +import shutil +import sys + +from pathlib import Path + +from memos_w_scheduler_for_test import init_task + +from memos.configs.mem_cube import GeneralMemCubeConfig +from memos.configs.mem_os import MOSConfig +from memos.configs.mem_scheduler import AuthConfig 
+from memos.log import get_logger +from memos.mem_cube.general import GeneralMemCube +from memos.mem_scheduler.analyzer.mos_for_test_scheduler import MOSForTestScheduler + + +FILE_PATH = Path(__file__).absolute() +BASE_DIR = FILE_PATH.parent.parent.parent +sys.path.insert(0, str(BASE_DIR)) + +# Enable execution from any working directory + +logger = get_logger(__name__) + +if __name__ == "__main__": + # set up data + conversations, questions = init_task() + + # set configs + mos_config = MOSConfig.from_yaml_file( + f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler_and_openai.yaml" + ) + + mem_cube_config = GeneralMemCubeConfig.from_yaml_file( + f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config.yaml" + ) + + # default local graphdb uri + if AuthConfig.default_config_exists(): + auth_config = AuthConfig.from_local_config() + + mos_config.mem_reader.config.llm.config.api_key = auth_config.openai.api_key + mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url + + mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri + mem_cube_config.text_mem.config.graph_db.config.user = auth_config.graph_db.user + mem_cube_config.text_mem.config.graph_db.config.password = auth_config.graph_db.password + mem_cube_config.text_mem.config.graph_db.config.db_name = auth_config.graph_db.db_name + mem_cube_config.text_mem.config.graph_db.config.auto_create = ( + auth_config.graph_db.auto_create + ) + + # Initialization + mos = MOSForTestScheduler(mos_config) + + user_id = "user_1" + mos.create_user(user_id) + + mem_cube_id = "mem_cube_5" + mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}" + + if Path(mem_cube_name_or_path).exists(): + shutil.rmtree(mem_cube_name_or_path) + print(f"{mem_cube_name_or_path} is not empty, and has been removed.") + + mem_cube = GeneralMemCube(mem_cube_config) + mem_cube.dump(mem_cube_name_or_path) + mos.register_mem_cube( + 
mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id + ) + + mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id) + + # Add interfering conversations + file_path = Path(f"{BASE_DIR}/examples/data/mem_scheduler/scene_data.json") + scene_data = json.load(file_path.open("r", encoding="utf-8")) + mos.add(scene_data[0], user_id=user_id, mem_cube_id=mem_cube_id) + mos.add(scene_data[1], user_id=user_id, mem_cube_id=mem_cube_id) + + for item in questions: + print("===== Chat Start =====") + query = item["question"] + print(f"Query:\n {query}\n") + response = mos.chat(query=query, user_id=user_id) + print(f"Answer:\n {response}\n") + + mos.mem_scheduler.stop() diff --git a/examples/mem_scheduler/memos_w_scheduler.py b/examples/mem_scheduler/memos_w_scheduler.py index 63054586..28641507 100644 --- a/examples/mem_scheduler/memos_w_scheduler.py +++ b/examples/mem_scheduler/memos_w_scheduler.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: - from memos.mem_scheduler.schemas import ( + from memos.mem_scheduler.schemas.message_schemas import ( ScheduleLogForWebItem, ) diff --git a/examples/mem_scheduler/memos_w_scheduler_for_test.py b/examples/mem_scheduler/memos_w_scheduler_for_test.py index 074400ee..f710ad8c 100644 --- a/examples/mem_scheduler/memos_w_scheduler_for_test.py +++ b/examples/mem_scheduler/memos_w_scheduler_for_test.py @@ -1,6 +1,7 @@ import json import shutil import sys +import time from pathlib import Path @@ -9,7 +10,7 @@ from memos.configs.mem_scheduler import AuthConfig from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube -from memos.mem_scheduler.mos_for_test_scheduler import MOSForTestScheduler +from memos.mem_scheduler.analyzer.mos_for_test_scheduler import MOSForTestScheduler FILE_PATH = Path(__file__).absolute() @@ -19,6 +20,77 @@ logger = get_logger(__name__) +def display_memory_cube_stats(mos, user_id, mem_cube_id): + """Display detailed memory cube statistics.""" + print(f"\n๐Ÿ“Š MEMORY 
CUBE STATISTICS for {mem_cube_id}:") + print("-" * 60) + + mem_cube = mos.mem_cubes.get(mem_cube_id) + if not mem_cube: + print(" โŒ Memory cube not found") + return + + # Text memory stats + if mem_cube.text_mem: + text_mem = mem_cube.text_mem + working_memories = text_mem.get_working_memory() + all_memories = text_mem.get_all() + + print(" ๐Ÿ“ Text Memory:") + print(f" โ€ข Working Memory Items: {len(working_memories)}") + print( + f" โ€ข Total Memory Items: {len(all_memories) if isinstance(all_memories, list) else 'N/A'}" + ) + + if working_memories: + print(" โ€ข Working Memory Content Preview:") + for i, mem in enumerate(working_memories[:2]): + content = mem.memory[:60] + "..." if len(mem.memory) > 60 else mem.memory + print(f" {i + 1}. {content}") + + # Activation memory stats + if mem_cube.act_mem: + act_mem = mem_cube.act_mem + act_memories = list(act_mem.get_all()) + print(" โšก Activation Memory:") + print(f" โ€ข KV Cache Items: {len(act_memories)}") + if act_memories: + print( + f" โ€ข Latest Cache Size: {len(act_memories[-1].memory) if hasattr(act_memories[-1], 'memory') else 'N/A'}" + ) + + print("-" * 60) + + +def display_scheduler_status(mos): + """Display current scheduler status and configuration.""" + print("\nโš™๏ธ SCHEDULER STATUS:") + print("-" * 60) + + if not mos.mem_scheduler: + print(" โŒ Memory scheduler not initialized") + return + + scheduler = mos.mem_scheduler + print(f" ๐Ÿ”„ Scheduler Running: {scheduler._running}") + print(f" ๐Ÿ“Š Internal Queue Size: {scheduler.memos_message_queue.qsize()}") + print(f" ๐Ÿงต Parallel Dispatch: {scheduler.enable_parallel_dispatch}") + print(f" ๐Ÿ‘ฅ Max Workers: {scheduler.thread_pool_max_workers}") + print(f" โฑ๏ธ Consume Interval: {scheduler._consume_interval}s") + + if scheduler.monitor: + print(" ๐Ÿ“ˆ Monitor Active: โœ…") + print(f" ๐Ÿ—„๏ธ Database Engine: {'โœ…' if scheduler.db_engine else 'โŒ'}") + + if scheduler.dispatcher: + print(" ๐Ÿš€ Dispatcher Active: โœ…") + print( + f" ๐Ÿ”ง 
Dispatcher Status: {scheduler.dispatcher.status if hasattr(scheduler.dispatcher, 'status') else 'Unknown'}" + ) + + print("-" * 60) + + def init_task(): conversations = [ { @@ -83,6 +155,9 @@ def init_task(): if __name__ == "__main__": + print("๐Ÿš€ Starting Enhanced Memory Scheduler Test...") + print("=" * 80) + # set up data conversations, questions = init_task() @@ -111,6 +186,7 @@ def init_task(): ) # Initialization + print("๐Ÿ”ง Initializing MOS with Enhanced Scheduler...") mos = MOSForTestScheduler(mos_config) user_id = "user_1" @@ -121,7 +197,7 @@ def init_task(): if Path(mem_cube_name_or_path).exists(): shutil.rmtree(mem_cube_name_or_path) - print(f"{mem_cube_name_or_path} is not empty, and has been removed.") + print(f"๐Ÿ—‘๏ธ {mem_cube_name_or_path} is not empty, and has been removed.") mem_cube = GeneralMemCube(mem_cube_config) mem_cube.dump(mem_cube_name_or_path) @@ -129,6 +205,7 @@ def init_task(): mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id ) + print("๐Ÿ“š Adding initial conversations...") mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id) # Add interfering conversations @@ -137,11 +214,77 @@ def init_task(): mos.add(scene_data[0], user_id=user_id, mem_cube_id=mem_cube_id) mos.add(scene_data[1], user_id=user_id, mem_cube_id=mem_cube_id) - for item in questions: - print("===== Chat Start =====") - query = item["question"] - print(f"Query:\n {query}\n") - response = mos.chat(query=query, user_id=user_id) - print(f"Answer:\n {response}\n") + # Display initial status + print("\n๐Ÿ“Š INITIAL SYSTEM STATUS:") + display_scheduler_status(mos) + display_memory_cube_stats(mos, user_id, mem_cube_id) + + # Process questions with enhanced monitoring + print(f"\n๐ŸŽฏ Starting Question Processing ({len(questions)} questions)...") + question_start_time = time.time() + + for i, item in enumerate(questions, 1): + print(f"\n{'=' * 20} Question {i}/{len(questions)} {'=' * 20}") + print(f"๐Ÿ“ Category: 
{item['category']} | Difficulty: {item['difficulty']}") + print(f"๐ŸŽฏ Expected: {item['expected']}") + if "hint" in item: + print(f"๐Ÿ’ก Hint: {item['hint']}") + if "requires" in item: + print(f"๐Ÿ” Requires: {', '.join(item['requires'])}") + + print(f"\n๐Ÿš€ Processing Query: {item['question']}") + query_start_time = time.time() + + response = mos.chat(query=item["question"], user_id=user_id) + + query_time = time.time() - query_start_time + print(f"โฑ๏ธ Query Processing Time: {query_time:.3f}s") + print(f"๐Ÿค– Response: {response}") + + # Display intermediate status every 2 questions + if i % 2 == 0: + print(f"\n๐Ÿ“Š INTERMEDIATE STATUS (Question {i}):") + display_scheduler_status(mos) + display_memory_cube_stats(mos, user_id, mem_cube_id) + + total_processing_time = time.time() - question_start_time + print(f"\nโฑ๏ธ Total Question Processing Time: {total_processing_time:.3f}s") + + # Display final scheduler performance summary + print("\n" + "=" * 80) + print("๐Ÿ“Š FINAL SCHEDULER PERFORMANCE SUMMARY") + print("=" * 80) + + summary = mos.get_scheduler_summary() + print(f"๐Ÿ”ข Total Queries Processed: {summary['total_queries']}") + print(f"โšก Total Scheduler Calls: {summary['total_scheduler_calls']}") + print(f"โฑ๏ธ Average Scheduler Response Time: {summary['average_scheduler_response_time']:.3f}s") + print(f"๐Ÿง  Memory Optimizations Applied: {summary['memory_optimization_count']}") + print(f"๐Ÿ”„ Working Memory Updates: {summary['working_memory_updates']}") + print(f"โšก Activation Memory Updates: {summary['activation_memory_updates']}") + print(f"๐Ÿ“ˆ Average Query Processing Time: {summary['average_query_processing_time']:.3f}s") + + # Performance insights + print("\n๐Ÿ’ก PERFORMANCE INSIGHTS:") + if summary["total_scheduler_calls"] > 0: + optimization_rate = ( + summary["memory_optimization_count"] / summary["total_scheduler_calls"] + ) * 100 + print(f" โ€ข Memory Optimization Rate: {optimization_rate:.1f}%") + + if 
summary["average_scheduler_response_time"] < 0.1: + print(" โ€ข Scheduler Performance: ๐ŸŸข Excellent (< 100ms)") + elif summary["average_scheduler_response_time"] < 0.5: + print(" โ€ข Scheduler Performance: ๐ŸŸก Good (100-500ms)") + else: + print(" โ€ข Scheduler Performance: ๐Ÿ”ด Needs Improvement (> 500ms)") + + # Final system status + print("\n๐Ÿ” FINAL SYSTEM STATUS:") + display_scheduler_status(mos) + display_memory_cube_stats(mos, user_id, mem_cube_id) + + print("=" * 80) + print("๐Ÿ Test completed successfully!") mos.mem_scheduler.stop() diff --git a/examples/mem_scheduler/rabbitmq_example.py b/examples/mem_scheduler/rabbitmq_example.py index ba573238..5e40eaad 100644 --- a/examples/mem_scheduler/rabbitmq_example.py +++ b/examples/mem_scheduler/rabbitmq_example.py @@ -2,7 +2,7 @@ import time from memos.configs.mem_scheduler import AuthConfig -from memos.mem_scheduler.general_modules.rabbitmq_service import RabbitMQSchedulerModule +from memos.mem_scheduler.webservice_modules.rabbitmq_service import RabbitMQSchedulerModule def publish_message(rabbitmq_module, message): diff --git a/examples/mem_scheduler/try_schedule_modules.py b/examples/mem_scheduler/try_schedule_modules.py index 8c5d1415..634d69c3 100644 --- a/examples/mem_scheduler/try_schedule_modules.py +++ b/examples/mem_scheduler/try_schedule_modules.py @@ -12,8 +12,8 @@ from memos.configs.mem_scheduler import AuthConfig from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube +from memos.mem_scheduler.analyzer.mos_for_test_scheduler import MOSForTestScheduler from memos.mem_scheduler.general_scheduler import GeneralScheduler -from memos.mem_scheduler.mos_for_test_scheduler import MOSForTestScheduler from memos.mem_scheduler.schemas.general_schemas import ( NOT_APPLICABLE_TYPE, ) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index e7cc5d65..4436e208 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -271,7 +271,7 @@ def get_mysql_config() 
-> dict[str, Any]: def get_scheduler_config() -> dict[str, Any]: """Get scheduler configuration.""" return { - "backend": "general_scheduler", + "backend": "optimized_scheduler", "config": { "top_k": int(os.getenv("MOS_SCHEDULER_TOP_K", "10")), "act_mem_update_interval": int( diff --git a/src/memos/configs/mem_scheduler.py b/src/memos/configs/mem_scheduler.py index 4d62bd11..a36f3e2f 100644 --- a/src/memos/configs/mem_scheduler.py +++ b/src/memos/configs/mem_scheduler.py @@ -6,7 +6,7 @@ from pydantic import ConfigDict, Field, field_validator, model_validator from memos.configs.base import BaseConfig -from memos.mem_scheduler.general_modules.misc import DictConversionMixin +from memos.mem_scheduler.general_modules.misc import DictConversionMixin, EnvConfigMixin from memos.mem_scheduler.schemas.general_schemas import ( BASE_DIR, DEFAULT_ACT_MEM_DUMP_PATH, @@ -64,6 +64,19 @@ class GeneralSchedulerConfig(BaseSchedulerConfig): default=20, description="Capacity of the activation memory monitor" ) + # Database configuration for ORM persistence + db_path: str | None = Field( + default=None, + description="Path to SQLite database file for ORM persistence. If None, uses default scheduler_orm.db", + ) + db_url: str | None = Field( + default=None, + description="Database URL for ORM persistence (e.g., mysql://user:pass@host/db). 
Takes precedence over db_path", + ) + enable_orm_persistence: bool = Field( + default=True, description="Whether to enable ORM-based persistence for monitors" + ) + class SchedulerConfigFactory(BaseConfig): """Factory class for creating scheduler configurations.""" @@ -74,6 +87,7 @@ class SchedulerConfigFactory(BaseConfig): model_config = ConfigDict(extra="forbid", strict=True) backend_to_class: ClassVar[dict[str, Any]] = { "general_scheduler": GeneralSchedulerConfig, + "optimized_scheduler": GeneralSchedulerConfig, # optimized_scheduler uses same config as general_scheduler } @field_validator("backend") @@ -94,6 +108,8 @@ def create_config(self) -> "SchedulerConfigFactory": # ************************* Auth ************************* class RabbitMQConfig( BaseConfig, + DictConversionMixin, + EnvConfigMixin, ): host_name: str = Field(default="", description="Endpoint for RabbitMQ instance access") user_name: str = Field(default="", description="Static username for RabbitMQ instance") @@ -110,7 +126,7 @@ class RabbitMQConfig( ) -class GraphDBAuthConfig(BaseConfig): +class GraphDBAuthConfig(BaseConfig, DictConversionMixin, EnvConfigMixin): uri: str = Field( default="bolt://localhost:7687", description="URI for graph database access (e.g., bolt://host:port)", @@ -127,7 +143,7 @@ class GraphDBAuthConfig(BaseConfig): ) -class OpenAIConfig(BaseConfig): +class OpenAIConfig(BaseConfig, DictConversionMixin, EnvConfigMixin): api_key: str = Field(default="", description="API key for OpenAI service") base_url: str = Field(default="", description="Base URL for API endpoint") default_model: str = Field(default="", description="Default model to use") @@ -183,6 +199,25 @@ def from_local_config(cls, config_path: str | Path | None = None) -> "AuthConfig "Please use YAML (.yaml, .yml) or JSON (.json) files." ) + @classmethod + def from_local_env(cls) -> "AuthConfig": + """Creates an AuthConfig instance by loading configuration from environment variables. 
+ + This method loads configuration for all nested components (RabbitMQ, OpenAI, GraphDB) + from their respective environment variables using each component's specific prefix. + + Returns: + AuthConfig: Configured instance with values from environment variables + + Raises: + ValueError: If any required environment variables are missing + """ + return cls( + rabbitmq=RabbitMQConfig.from_env(), + openai=OpenAIConfig.from_env(), + graph_db=GraphDBAuthConfig.from_env(), + ) + def set_openai_config_to_environment(self): # Set environment variables os.environ["OPENAI_API_KEY"] = self.openai.api_key diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index a201e22c..0d612eec 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -125,12 +125,16 @@ def _initialize_mem_scheduler(self) -> GeneralScheduler: "missing required 'llm' attribute" ) self._mem_scheduler.initialize_modules( - chat_llm=self.chat_llm, process_llm=self.chat_llm + chat_llm=self.chat_llm, + process_llm=self.chat_llm, + db_engine=self.user_manager.engine, ) else: # Configure scheduler general_modules self._mem_scheduler.initialize_modules( - chat_llm=self.chat_llm, process_llm=self.mem_reader.llm + chat_llm=self.chat_llm, + process_llm=self.mem_reader.llm, + db_engine=self.user_manager.engine, ) self._mem_scheduler.start() return self._mem_scheduler diff --git a/src/memos/mem_scheduler/analyzer/__init__.py b/src/memos/mem_scheduler/analyzer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py b/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py new file mode 100644 index 00000000..7cd085ad --- /dev/null +++ b/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py @@ -0,0 +1,569 @@ +from datetime import datetime + +from memos.configs.mem_os import MOSConfig +from memos.log import get_logger +from memos.mem_os.main import MOS +from memos.mem_scheduler.schemas.general_schemas import ( + 
ANSWER_LABEL, + MONITOR_WORKING_MEMORY_TYPE, + QUERY_LABEL, +) +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem + + +logger = get_logger(__name__) + + +class MOSForTestScheduler(MOS): + """This class is only to test abilities of mem scheduler with enhanced monitoring""" + + def __init__(self, config: MOSConfig): + super().__init__(config) + self.memory_helpfulness_analysis = [] + + def _str_memories(self, memories: list[str]) -> str: + """Format memories for display.""" + if not memories: + return "No memories." + return "\n".join(f"{i + 1}. {memory}" for i, memory in enumerate(memories)) + + def _analyze_memory_helpfulness( + self, + query: str, + working_memories_before: list, + working_memories_after: list, + scheduler_memories: list, + ): + """Analyze how helpful each memory is for answering the current query.""" + print("\n" + "=" * 80) + print("๐Ÿง  MEMORY HELPFULNESS ANALYSIS FOR QUERY") + print("=" * 80) + + print(f"๐Ÿ“ Query: {query}") + print(f"๐Ÿ“Š Working Memories Before Scheduler: {len(working_memories_before)}") + print(f"๐Ÿ“Š Working Memories After Scheduler: {len(working_memories_after)}") + print(f"๐Ÿ“Š Working Memories from Monitor: {len(scheduler_memories)}") + + # Display working memories before scheduler (first 5 only) + if working_memories_before: + print("\n๐Ÿ”„ WORKING MEMORIES BEFORE SCHEDULER (first 5):") + for i, mem in enumerate(working_memories_before[:5]): + print(f" {i + 1}. {mem}") + + # Display working memories after scheduler (first 5 only) + if working_memories_after: + print("\n๐Ÿ”„ WORKING MEMORIES AFTER SCHEDULER (first 5):") + for i, mem in enumerate(working_memories_after[:5]): + print(f" {i + 1}. {mem}") + + # Display scheduler memories from monitor (first 5 only) + if scheduler_memories: + print("\n๐Ÿ”„ WORKING MEMORIES FROM MONITOR (first 5):") + for i, mem in enumerate(scheduler_memories[:5]): + print(f" {i + 1}. 
{mem}") + + # Batch assess working memory helpfulness before scheduler + if working_memories_before: + print( + f"\n๐Ÿ”„ WORKING MEMORY HELPFULNESS BEFORE SCHEDULER ({len(working_memories_before)}):" + ) + before_assessment = self._batch_assess_memories( + query, working_memories_before[:5], "before scheduler" + ) + for i, (_mem, score, reason) in enumerate(before_assessment): + print(f" {i + 1}. Helpfulness: {score}/10 - {reason}") + + # Batch assess working memory helpfulness after scheduler + if working_memories_after: + print( + f"\n๐Ÿ”„ WORKING MEMORY HELPFULNESS AFTER SCHEDULER ({len(working_memories_after)}):" + ) + after_assessment = self._batch_assess_memories( + query, working_memories_after[:5], "after scheduler" + ) + for i, (_mem, score, reason) in enumerate(after_assessment): + print(f" {i + 1}. Helpfulness: {score}/10 - {reason}") + + # Batch assess scheduler memories from monitor + if scheduler_memories: + print(f"\n๐Ÿ”„ WORKINGMEMORIES FROM MONITOR HELPFULNESS ({len(scheduler_memories)}):") + scheduler_assessment = self._batch_assess_memories( + query, scheduler_memories[:5], "from monitor" + ) + for i, (_mem, score, reason) in enumerate(scheduler_assessment): + print(f" {i + 1}. 
Helpfulness: {score}/10 - {reason}") + + # Overall assessment - compare before vs after vs scheduler + print("\n๐Ÿ’ก OVERALL ASSESSMENT:") + if working_memories_before and working_memories_after: + before_scores = ( + [score for _, score, _ in before_assessment] + if "before_assessment" in locals() + else [] + ) + after_scores = ( + [score for _, score, _ in after_assessment] + if "after_assessment" in locals() + else [] + ) + scheduler_scores = ( + [score for _, score, _ in scheduler_assessment] + if "scheduler_assessment" in locals() + else [] + ) + + avg_before_helpfulness = sum(before_scores) / len(before_scores) + avg_after_helpfulness = sum(after_scores) / len(after_scores) + + print(f" Average Helpfulness Before Scheduler: {avg_before_helpfulness:.1f}/10") + print(f" Average Helpfulness After Scheduler: {avg_after_helpfulness:.1f}/10") + print(f" Improvement: {avg_after_helpfulness - avg_before_helpfulness:+.1f}") + + if avg_after_helpfulness > avg_before_helpfulness: + print(" โœ… Scheduler improved working memory quality") + elif avg_after_helpfulness < avg_before_helpfulness: + print(" โŒ Scheduler decreased working memory quality") + else: + print(" โš–๏ธ Scheduler maintained working memory quality") + + # Compare scheduler memories vs working memories + + avg_scheduler_helpfulness = sum(scheduler_scores) / len(scheduler_scores) + print( + f" Average Helpfulness of Memories from Monitors: {avg_scheduler_helpfulness:.1f}/10" + ) + + if avg_scheduler_helpfulness > avg_after_helpfulness: + print(" ๐ŸŽฏ Memories from Monitors are more helpful than working memories") + elif avg_scheduler_helpfulness < avg_after_helpfulness: + print(" โš ๏ธ Working memories are more helpful than Memories from Monitors") + else: + print( + " โš–๏ธ WORKING Memories from Monitors and working memories have similar helpfulness" + ) + + # Record analysis results + self.memory_helpfulness_analysis.append( + { + "query": query, + "working_memories_before_count": 
len(working_memories_before), + "working_memories_after_count": len(working_memories_after), + "scheduler_memories_count": len(scheduler_memories), + "working_helpfulness_before": [score for _, score, _ in before_assessment] + if "before_assessment" in locals() + else [], + "working_helpfulness_after": [score for _, score, _ in after_assessment] + if "after_assessment" in locals() + else [], + "scheduler_helpfulness": [score for _, score, _ in scheduler_assessment] + if "scheduler_assessment" in locals() + else [], + } + ) + + print("=" * 80 + "\n") + + def _batch_assess_memories(self, query: str, memories: list, context: str) -> list: + """Use LLM to assess multiple memories at once and compare their quality.""" + try: + # Create prompt for batch assessment + memories_text = "\n".join([f"{i + 1}. {mem}" for i, mem in enumerate(memories)]) + + assessment_prompt = f""" + Task: Assess and compare the helpfulness of multiple memories for answering a query. + + Query: "{query}" + + Context: These are working memories {context}. + + Memories to assess: + {memories_text} + + Please provide: + 1. A helpfulness score from 1-10 for each memory (where 10 = extremely helpful, 1 = not helpful at all) + 2. A brief reason for each score + 3. 
Rank the memories from most helpful to least helpful + + Format your response as: + Memory 1: Score [number] - [reason] + Memory 2: Score [number] - [reason] + Memory 3: Score [number] - [reason] + Memory 4: Score [number] - [reason] + Memory 5: Score [number] - [reason] + + Ranking: [memory numbers in order from most to least helpful] + + Consider: + - Direct relevance to the query + - Information completeness + - How directly it answers the question + - Whether it provides useful context or background + - Compare memories against each other for relative quality + """ + + # Use the chat LLM to get batch assessment + messages = [{"role": "user", "content": assessment_prompt}] + response = self.chat_llm.generate(messages) + + # Parse the response to extract scores and reasons + assessment_results = [] + lines = response.strip().split("\n") + + for i, mem in enumerate(memories): + score = 5 # Default score + reason = "LLM assessment failed, using default score" + + # Look for the corresponding memory line + for line in lines: + if line.startswith(f"Memory {i + 1}:"): + try: + # Extract score and reason from line like "Memory 1: Score 8 - Highly relevant" + parts = line.split("Score ")[1].split(" - ", 1) + score = int(parts[0]) + score = max(1, min(10, score)) # Ensure score is 1-10 + reason = parts[1] if len(parts) > 1 else "No reason provided" + except Exception: + pass + break + + assessment_results.append((mem, score, reason)) + + return assessment_results + + except Exception as e: + logger.warning(f"LLM batch assessment failed: {e}, using fallback scoring") + # Fallback to individual assessment if batch fails + return [ + ( + mem, + self._assess_memory_helpfulness(query, mem)["score"], + self._assess_memory_helpfulness(query, mem)["reason"], + ) + for mem in memories + ] + + def _assess_memory_helpfulness(self, query: str, memory: str) -> dict: + """Use LLM to assess how helpful a memory is for answering the current query (1-10 scale)""" + try: + # Create prompt 
for LLM assessment + assessment_prompt = f""" + Task: Rate how helpful this memory is for answering the given query on a scale of 1-10. + + Query: "{query}" + + Memory: "{memory}" + + Please provide: + 1. A score from 1-10 (where 10 = extremely helpful, 1 = not helpful at all) + 2. A brief reason for your score + + Format your response as: + Score: [number] + Reason: [your explanation] + + Consider: + - Direct relevance to the query + - Information completeness + - How directly it answers the question + - Whether it provides useful context or background + """ + + # Use the chat LLM to get assessment + messages = [{"role": "user", "content": assessment_prompt}] + response = self.chat_llm.generate(messages) + + # Parse the response to extract score and reason + lines = response.strip().split("\n") + score = 5 # Default score + reason = "LLM assessment failed, using default score" + + for line in lines: + if line.startswith("Score:"): + try: + score_text = line.split(":")[1].strip() + score = int(score_text) + score = max(1, min(10, score)) # Ensure score is 1-10 + except Exception: + pass + elif line.startswith("Reason:"): + reason = line.split(":", 1)[1].strip() + + return {"score": score, "reason": reason} + + except Exception as e: + logger.warning(f"LLM assessment failed: {e}, using fallback scoring") + # Fallback to simple keyword matching if LLM fails + return self._fallback_memory_assessment(query, memory) + + def _fallback_memory_assessment(self, query: str, memory: str) -> dict: + """Fallback assessment method using keyword matching if LLM fails""" + query_lower = query.lower() + memory_lower = memory.lower() + + # Keyword matching + query_words = set(query_lower.split()) + memory_words = set(memory_lower.split()) + common_words = query_words.intersection(memory_words) + + # Semantic relevance scoring + score = 0 + + # Exact keyword matches (highest weight) + if len(common_words) > 0: + score += min(len(common_words) * 2, 6) + + # Partial matches (medium 
weight) + partial_matches = sum( + 1 for qw in query_words for mw in memory_words if qw in mw or mw in qw + ) + if partial_matches > 0: + score += min(partial_matches, 3) + + # Topic relevance (through common topic words) + topic_words = [ + "problem", + "solution", + "answer", + "method", + "reason", + "result", + "analysis", + "compare", + "explain", + ] + topic_matches = sum(1 for topic in topic_words if topic in memory_lower) + score += topic_matches + + # Ensure score is 1-10 + score = max(1, min(10, score)) + + # Determine helpfulness level + if score >= 8: + reason = "Highly relevant, directly answers the query" + elif score >= 6: + reason = "Relevant, provides useful information" + elif score >= 4: + reason = "Partially relevant, somewhat helpful" + elif score >= 2: + reason = "Low relevance, limited help" + else: + reason = "Very low relevance, minimal help" + + return {"score": score, "reason": reason} + + def _assess_ranking_quality(self, rank: int, helpfulness: int) -> str: + """Use LLM to assess whether the memory ranking is reasonable""" + try: + # Create prompt for LLM ranking assessment + ranking_prompt = f""" + Task: Assess whether this memory ranking is reasonable. + + Context: A memory with helpfulness score {helpfulness}/10 is ranked at position {rank}. + + Please evaluate if this ranking makes sense and provide a brief assessment. + + Consider: + - Higher helpfulness scores should generally rank higher + - Rank 1 should typically have the highest helpfulness + - The relationship between rank and helpfulness + + Provide a brief assessment in one sentence. 
+ """ + + # Use the chat LLM to get assessment + messages = [{"role": "user", "content": ranking_prompt}] + response = self.chat_llm.generate(messages) + + return response.strip() + + except Exception as e: + logger.warning(f"LLM ranking assessment failed: {e}, using fallback assessment") + # Fallback assessment + if rank == 1 and helpfulness >= 8: + return "โœ… Ranking is reasonable - most helpful memory ranked first" + elif rank == 1 and helpfulness <= 4: + return "โŒ Ranking is unreasonable - first ranked memory has low helpfulness" + elif rank <= 3 and helpfulness >= 6: + return "โœ… Ranking is reasonable - high helpfulness memory ranked high" + elif rank <= 3 and helpfulness <= 3: + return "โš ๏ธ Ranking may be unreasonable - low helpfulness memory ranked high" + elif rank > 3 and helpfulness >= 7: + return "โš ๏ธ Ranking may be unreasonable - high helpfulness memory ranked low" + else: + return "๐ŸŸก Ranking is acceptable - helpfulness and rank generally match" + + def chat(self, query: str, user_id: str | None = None) -> str: + """ + Chat with the MOS with memory helpfulness analysis. + + Args: + query (str): The user's query. + user_id (str | None): The user ID. + + Returns: + str: The response from the MOS. 
+ """ + target_user_id = user_id if user_id is not None else self.user_id + accessible_cubes = self.user_manager.get_user_cubes(target_user_id) + user_cube_ids = [cube.cube_id for cube in accessible_cubes] + + if target_user_id not in self.chat_history_manager: + self._register_chat_history(target_user_id) + + chat_history = self.chat_history_manager[target_user_id] + topk_for_scheduler = 2 + + if self.config.enable_textual_memory and self.mem_cubes: + memories_all = [] + for mem_cube_id, mem_cube in self.mem_cubes.items(): + if mem_cube_id not in user_cube_ids: + continue + if not mem_cube.text_mem: + continue + + # Get working memories BEFORE scheduler + working_memories_before = [m.memory for m in mem_cube.text_mem.get_working_memory()] + + message_item = ScheduleMessageItem( + user_id=target_user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + label=QUERY_LABEL, + content=query, + timestamp=datetime.now(), + ) + + print(f"\n๐Ÿš€ Starting Scheduler for {mem_cube_id}...") + + # Force scheduler to run immediately + self.mem_scheduler.monitor.query_trigger_interval = 0 + self.mem_scheduler._query_message_consumer(messages=[message_item]) + + # Get scheduler memories + scheduler_memories = self.mem_scheduler.monitor.get_monitor_memories( + user_id=target_user_id, + mem_cube_id=mem_cube_id, + memory_type=MONITOR_WORKING_MEMORY_TYPE, + top_k=20, + ) + + # Get working memories AFTER scheduler + working_memories_after = [m.memory for m in mem_cube.text_mem.get_working_memory()] + + # Get mem_cube memories for response generation + memories = mem_cube.text_mem.search( + query, + top_k=self.config.top_k - topk_for_scheduler, + info={ + "user_id": target_user_id, + "session_id": self.session_id, + "chat_history": chat_history.chat_history, + }, + ) + text_memories = [m.memory for m in memories] + + # Analyze memory helpfulness - compare before vs after vs scheduler + self._analyze_memory_helpfulness( + query, working_memories_before, working_memories_after, 
scheduler_memories + ) + + # Combine all memories for response generation + memories_all.extend(scheduler_memories[:topk_for_scheduler]) + memories_all.extend(text_memories) + memories_all = list(set(memories_all)) + + logger.info(f"๐Ÿง  [Memory] Searched memories:\n{self._str_memories(memories_all)}\n") + system_prompt = self._build_system_prompt(memories_all) + else: + system_prompt = self._build_system_prompt() + + current_messages = [ + {"role": "system", "content": system_prompt}, + *chat_history.chat_history, + {"role": "user", "content": query}, + ] + past_key_values = None + + if self.config.enable_activation_memory: + assert self.config.chat_model.backend == "huggingface", ( + "Activation memory only used for huggingface backend." + ) + # TODO this only one cubes + for mem_cube_id, mem_cube in self.mem_cubes.items(): + if mem_cube_id not in user_cube_ids: + continue + if mem_cube.act_mem: + kv_cache = next(iter(mem_cube.act_mem.get_all()), None) + past_key_values = ( + kv_cache.memory if (kv_cache and hasattr(kv_cache, "memory")) else None + ) + break + # Generate response + response = self.chat_llm.generate(current_messages, past_key_values=past_key_values) + else: + response = self.chat_llm.generate(current_messages) + + logger.info(f"๐Ÿค– [Assistant] {response}\n") + chat_history.chat_history.append({"role": "user", "content": query}) + chat_history.chat_history.append({"role": "assistant", "content": response}) + self.chat_history_manager[user_id] = chat_history + + # Submit message to scheduler for answer processing + for accessible_mem_cube in accessible_cubes: + mem_cube_id = accessible_mem_cube.cube_id + mem_cube = self.mem_cubes[mem_cube_id] + if self.enable_mem_scheduler and self.mem_scheduler is not None: + message_item = ScheduleMessageItem( + user_id=target_user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + label=ANSWER_LABEL, + content=response, + timestamp=datetime.now(), + ) + 
self.mem_scheduler.submit_messages(messages=[message_item]) + + return response + + def get_memory_helpfulness_summary(self) -> dict: + """Get summary of memory helpfulness analysis.""" + if not self.memory_helpfulness_analysis: + return {"message": "No memory helpfulness analysis data available"} + + total_queries = len(self.memory_helpfulness_analysis) + + # Calculate average helpfulness for working memories before scheduler + before_scores = [] + for analysis in self.memory_helpfulness_analysis: + before_scores.extend(analysis["working_helpfulness_before"]) + + # Calculate average helpfulness for working memories after scheduler + after_scores = [] + for analysis in self.memory_helpfulness_analysis: + after_scores.extend(analysis["working_helpfulness_after"]) + + # Calculate average helpfulness for scheduler memories from monitor + scheduler_scores = [] + for analysis in self.memory_helpfulness_analysis: + scheduler_scores.extend(analysis["scheduler_helpfulness"]) + + avg_before_helpfulness = sum(before_scores) / len(before_scores) if before_scores else 0 + avg_after_helpfulness = sum(after_scores) / len(after_scores) if after_scores else 0 + avg_scheduler_helpfulness = ( + sum(scheduler_scores) / len(scheduler_scores) if scheduler_scores else 0 + ) + + return { + "total_queries": total_queries, + "working_memories_before_analyzed": len(before_scores), + "working_memories_after_analyzed": len(after_scores), + "scheduler_memories_analyzed": len(scheduler_scores), + "average_helpfulness_before_scheduler": f"{avg_before_helpfulness:.1f}/10", + "average_helpfulness_after_scheduler": f"{avg_after_helpfulness:.1f}/10", + "average_helpfulness_scheduler_memories": f"{avg_scheduler_helpfulness:.1f}/10", + "overall_improvement": f"{avg_after_helpfulness - avg_before_helpfulness:+.1f}", + "improvement_percentage": f"{((avg_after_helpfulness - avg_before_helpfulness) / avg_before_helpfulness * 100):+.1f}%" + if avg_before_helpfulness > 0 + else "N/A", + 
"scheduler_vs_working_comparison": f"{avg_scheduler_helpfulness - avg_after_helpfulness:+.1f}", + } diff --git a/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py b/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py new file mode 100644 index 00000000..87876061 --- /dev/null +++ b/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from memos.log import get_logger +from memos.mem_scheduler.general_scheduler import GeneralScheduler +from memos.mem_scheduler.schemas.general_schemas import ( + DEFAULT_MAX_QUERY_KEY_WORDS, + UserID, +) +from memos.mem_scheduler.schemas.monitor_schemas import QueryMonitorItem + + +if TYPE_CHECKING: + from memos.memories.textual.tree import TextualMemoryItem + + +logger = get_logger(__name__) + + +class SchedulerForEval(GeneralScheduler): + """ + A scheduler class that inherits from GeneralScheduler and provides evaluation-specific functionality. + This class extends GeneralScheduler with evaluation methods. + """ + + def __init__(self, config): + """ + Initialize the SchedulerForEval with the same configuration as GeneralScheduler. + + Args: + config: Configuration object for the scheduler + """ + super().__init__(config) + + def update_working_memory_for_eval( + self, query: str, user_id: UserID | str, top_k: int + ) -> list[str]: + """ + Update working memory based on query and return the updated memory list. 
+ + Args: + query: The query string + user_id: User identifier + top_k: Number of top memories to return + + Returns: + List of memory strings from updated working memory + """ + self.monitor.register_query_monitor_if_not_exists( + user_id=user_id, mem_cube_id=self.current_mem_cube_id + ) + + query_keywords = self.monitor.extract_query_keywords(query=query) + logger.info(f'Extract keywords "{query_keywords}" from query "{query}"') + + item = QueryMonitorItem( + user_id=user_id, + mem_cube_id=self.current_mem_cube_id, + query_text=query, + keywords=query_keywords, + max_keywords=DEFAULT_MAX_QUERY_KEY_WORDS, + ) + query_db_manager = self.monitor.query_monitors[user_id][self.current_mem_cube_id] + query_db_manager.obj.put(item=item) + # Sync with database after adding new item + query_db_manager.sync_with_orm() + logger.debug(f"Queries in monitor are {query_db_manager.obj.get_queries_with_timesort()}.") + + queries = [query] + + # recall + mem_cube = self.current_mem_cube + text_mem_base = mem_cube.text_mem + + cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory() + text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory] + intent_result = self.monitor.detect_intent( + q_list=queries, text_working_memory=text_working_memory + ) + + if intent_result["trigger_retrieval"]: + missing_evidences = intent_result["missing_evidences"] + num_evidence = len(missing_evidences) + k_per_evidence = max(1, top_k // max(1, num_evidence)) + new_candidates = [] + for item in missing_evidences: + logger.info(f"missing_evidences: {item}") + results: list[TextualMemoryItem] = self.retriever.search( + query=item, + mem_cube=mem_cube, + top_k=k_per_evidence, + method=self.search_method, + ) + logger.info( + f"search results for {missing_evidences}: {[one.memory for one in results]}" + ) + new_candidates.extend(results) + print( + f"missing_evidences: {missing_evidences} and get {len(new_candidates)} new candidate memories." 
+ ) + else: + new_candidates = [] + print(f"intent_result: {intent_result}. not triggered") + + # rerank + new_order_working_memory = self.replace_working_memory( + user_id=user_id, + mem_cube_id=self.current_mem_cube_id, + mem_cube=self.current_mem_cube, + original_memory=cur_working_memory, + new_memory=new_candidates, + ) + new_order_working_memory = new_order_working_memory[:top_k] + logger.info(f"size of new_order_working_memory: {len(new_order_working_memory)}") + + return [m.memory for m in new_order_working_memory] + + def evaluate_query_with_memories( + self, query: str, memory_texts: list[str], user_id: UserID | str + ) -> bool: + """ + Use LLM to evaluate whether the given memories can answer the query. + + Args: + query: The query string to evaluate + memory_texts: List of memory texts to check against + user_id: User identifier + + Returns: + Boolean indicating whether the memories can answer the query + """ + queries = [query] + intent_result = self.monitor.detect_intent(q_list=queries, text_working_memory=memory_texts) + return intent_result["trigger_retrieval"] + + def search_for_eval( + self, query: str, user_id: UserID | str, top_k: int, scheduler_flag: bool = True + ) -> tuple[list[str], bool]: + """ + Original search_for_eval function refactored to use the new decomposed functions. 
+ + Args: + query: The query string + user_id: User identifier + top_k: Number of top memories to return + scheduler_flag: Whether to update working memory or just evaluate + + Returns: + Tuple of (memory_list, can_answer_boolean) + """ + if not scheduler_flag: + # Get current working memory without updating + mem_cube = self.current_mem_cube + text_mem_base = mem_cube.text_mem + cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory() + text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory] + + # Use the evaluation function to check if memories can answer the query + can_answer = self.evaluate_query_with_memories( + query=query, memory_texts=text_working_memory, user_id=user_id + ) + return text_working_memory, can_answer + else: + # Update working memory and get the result + updated_memories = self.update_working_memory_for_eval( + query=query, user_id=user_id, top_k=top_k + ) + + # Use the evaluation function to check if memories can answer the query + can_answer = self.evaluate_query_with_memories( + query=query, memory_texts=updated_memories, user_id=user_id + ) + return updated_memories, can_answer diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 44bc7da3..b6ef00d8 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -5,16 +5,16 @@ from datetime import datetime from pathlib import Path +from sqlalchemy.engine import Engine + from memos.configs.mem_scheduler import AuthConfig, BaseSchedulerConfig from memos.llms.base import BaseLLM from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue -from memos.mem_scheduler.general_modules.rabbitmq_service import RabbitMQSchedulerModule -from 
memos.mem_scheduler.general_modules.redis_service import RedisSchedulerModule -from memos.mem_scheduler.general_modules.retriever import SchedulerRetriever from memos.mem_scheduler.general_modules.scheduler_logger import SchedulerLoggerModule +from memos.mem_scheduler.memory_manage_modules.retriever import SchedulerRetriever from memos.mem_scheduler.monitors.dispatcher_monitor import SchedulerDispatcherMonitor from memos.mem_scheduler.monitors.general_monitor import SchedulerGeneralMonitor from memos.mem_scheduler.schemas.general_schemas import ( @@ -33,6 +33,8 @@ from memos.mem_scheduler.utils.filter_utils import ( transform_name_to_key, ) +from memos.mem_scheduler.webservice_modules.rabbitmq_service import RabbitMQSchedulerModule +from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule from memos.memories.activation.kv import KVCacheMemory from memos.memories.activation.vllmkv import VLLMKVCacheItem, VLLMKVCacheMemory from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory @@ -62,6 +64,7 @@ def __init__(self, config: BaseSchedulerConfig): ) self.retriever: SchedulerRetriever | None = None + self.db_engine: Engine | None = None self.monitor: SchedulerGeneralMonitor | None = None self.dispatcher_monitor: SchedulerDispatcherMonitor | None = None self.dispatcher = SchedulerDispatcher( @@ -70,12 +73,15 @@ def __init__(self, config: BaseSchedulerConfig): ) # internal message queue - self.max_internal_messae_queue_size = 100 + self.max_internal_message_queue_size = self.config.get( + "max_internal_message_queue_size", 100 + ) self.memos_message_queue: Queue[ScheduleMessageItem] = Queue( - maxsize=self.max_internal_messae_queue_size + maxsize=self.max_internal_message_queue_size ) + self.max_web_log_queue_size = self.config.get("max_web_log_queue_size", 50) self._web_log_message_queue: Queue[ScheduleLogForWebItem] = Queue( - maxsize=self.max_internal_messae_queue_size + maxsize=self.max_web_log_queue_size ) 
self._consumer_thread = None # Reference to our consumer thread self._running = False @@ -92,34 +98,57 @@ def __init__(self, config: BaseSchedulerConfig): self.auth_config = None self.rabbitmq_config = None - def initialize_modules(self, chat_llm: BaseLLM, process_llm: BaseLLM | None = None): + def initialize_modules( + self, + chat_llm: BaseLLM, + process_llm: BaseLLM | None = None, + db_engine: Engine | None = None, + ): if process_llm is None: process_llm = chat_llm - # initialize submodules - self.chat_llm = chat_llm - self.process_llm = process_llm - self.monitor = SchedulerGeneralMonitor(process_llm=self.process_llm, config=self.config) - self.dispatcher_monitor = SchedulerDispatcherMonitor(config=self.config) - self.retriever = SchedulerRetriever(process_llm=self.process_llm, config=self.config) + try: + # initialize submodules + self.chat_llm = chat_llm + self.process_llm = process_llm + self.db_engine = db_engine + self.monitor = SchedulerGeneralMonitor( + process_llm=self.process_llm, config=self.config, db_engine=self.db_engine + ) + self.db_engine = self.monitor.db_engine + self.dispatcher_monitor = SchedulerDispatcherMonitor(config=self.config) + self.retriever = SchedulerRetriever(process_llm=self.process_llm, config=self.config) + + if self.enable_parallel_dispatch: + self.dispatcher_monitor.initialize(dispatcher=self.dispatcher) + self.dispatcher_monitor.start() + + # initialize with auth_config + if self.auth_config_path is not None and Path(self.auth_config_path).exists(): + self.auth_config = AuthConfig.from_local_config(config_path=self.auth_config_path) + elif AuthConfig.default_config_exists(): + self.auth_config = AuthConfig.from_local_config() + else: + self.auth_config = AuthConfig.from_local_env() - if self.enable_parallel_dispatch: - self.dispatcher_monitor.initialize(dispatcher=self.dispatcher) - self.dispatcher_monitor.start() - - # initialize with auth_cofig - if self.auth_config_path is not None and 
Path(self.auth_config_path).exists(): - self.auth_config = AuthConfig.from_local_config(config_path=self.auth_config_path) - elif AuthConfig.default_config_exists(): - self.auth_config = AuthConfig.from_local_config() - else: - self.auth_config = None + if self.auth_config is not None: + self.rabbitmq_config = self.auth_config.rabbitmq + self.initialize_rabbitmq(config=self.rabbitmq_config) - if self.auth_config is not None: - self.rabbitmq_config = self.auth_config.rabbitmq - self.initialize_rabbitmq(config=self.rabbitmq_config) + logger.debug("GeneralScheduler has been initialized") + except Exception as e: + logger.error(f"Failed to initialize scheduler modules: {e}", exc_info=True) + # Clean up any partially initialized resources + self._cleanup_on_init_failure() + raise - logger.debug("GeneralScheduler has been initialized") + def _cleanup_on_init_failure(self): + """Clean up resources if initialization fails.""" + try: + if hasattr(self, "dispatcher_monitor") and self.dispatcher_monitor is not None: + self.dispatcher_monitor.stop() + except Exception as e: + logger.warning(f"Error during cleanup: {e}") @property def mem_cube(self) -> GeneralMemCube: @@ -200,8 +229,11 @@ def replace_working_memory( text_mem_base: TreeTextMemory = text_mem_base # process rerank memories with llm - query_monitor = self.monitor.query_monitors[user_id][mem_cube_id] - query_history = query_monitor.get_queries_with_timesort() + query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id] + # Sync with database to get latest query history + query_db_manager.sync_with_orm() + + query_history = query_db_manager.obj.get_queries_with_timesort() memories_with_new_order, rerank_success_flag = ( self.retriever.process_and_rerank_memories( queries=query_history, @@ -211,8 +243,27 @@ def replace_working_memory( ) ) - # update working memory monitors - query_keywords = query_monitor.get_keywords_collections() + # Filter completely unrelated memories according to query_history + 
logger.info(f"Filtering memories based on query history: {len(query_history)} queries") + filtered_memories, filter_success_flag = self.retriever.filter_unrelated_memories( + query_history=query_history, + memories=memories_with_new_order, + ) + + if filter_success_flag: + logger.info( + f"Memory filtering completed successfully. " + f"Filtered from {len(memories_with_new_order)} to {len(filtered_memories)} memories" + ) + memories_with_new_order = filtered_memories + else: + logger.warning( + "Memory filtering failed - keeping all memories as fallback. " + f"Original count: {len(memories_with_new_order)}" + ) + + # Update working memory monitors + query_keywords = query_db_manager.obj.get_keywords_collections() logger.info( f"Processing {len(memories_with_new_order)} memories with {len(query_keywords)} query keywords" ) @@ -235,7 +286,7 @@ def replace_working_memory( mem_monitors: list[MemoryMonitorItem] = self.monitor.working_memory_monitors[user_id][ mem_cube_id - ].get_sorted_mem_monitors(reverse=True) + ].obj.get_sorted_mem_monitors(reverse=True) new_working_memories = [mem_monitor.tree_memory_item for mem_monitor in mem_monitors] text_mem_base.replace_working_memory(memories=new_working_memories) @@ -278,6 +329,7 @@ def update_activation_memory( new_text_memories = new_memories else: logger.error("Not Implemented.") + return try: if isinstance(mem_cube.act_mem, VLLMKVCacheMemory): @@ -333,7 +385,9 @@ def update_activation_memory( ) except Exception as e: - logger.warning(f"MOS-based activation memory update failed: {e}", exc_info=True) + logger.error(f"MOS-based activation memory update failed: {e}", exc_info=True) + # Re-raise the exception if it's critical for the operation + # For now, we'll continue execution but this should be reviewed def update_activation_memory_periodically( self, @@ -358,7 +412,8 @@ def update_activation_memory_periodically( if ( user_id not in self.monitor.working_memory_monitors or mem_cube_id not in 
self.monitor.working_memory_monitors[user_id] - or len(self.monitor.working_memory_monitors[user_id][mem_cube_id].memories) == 0 + or len(self.monitor.working_memory_monitors[user_id][mem_cube_id].obj.memories) + == 0 ): logger.warning( "No memories found in working_memory_monitors, activation memory update is skipped" @@ -369,9 +424,13 @@ def update_activation_memory_periodically( user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube ) + # Sync with database to get latest activation memories + activation_db_manager = self.monitor.activation_memory_monitors[user_id][ + mem_cube_id + ] + activation_db_manager.sync_with_orm() new_activation_memories = [ - m.memory_text - for m in self.monitor.activation_memory_monitors[user_id][mem_cube_id].memories + m.memory_text for m in activation_db_manager.obj.memories ] logger.info( @@ -412,6 +471,11 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt messages = [messages] # transform single message to list for message in messages: + if not isinstance(message, ScheduleMessageItem): + error_msg = f"Invalid message type: {type(message)}, expected ScheduleMessageItem" + logger.error(error_msg) + raise TypeError(error_msg) + self.memos_message_queue.put(message) logger.info(f"Submitted message: {message.label} - {message.content}") @@ -427,6 +491,11 @@ def _submit_web_logs( messages = [messages] # transform single message to list for message in messages: + if not isinstance(message, ScheduleLogForWebItem): + error_msg = f"Invalid message type: {type(message)}, expected ScheduleLogForWebItem" + logger.error(error_msg) + raise TypeError(error_msg) + self._web_log_message_queue.put(message) message_info = message.debug_info() logger.debug(f"Submitted Scheduling log for web: {message_info}") @@ -461,25 +530,26 @@ def _message_consumer(self) -> None: """ while self._running: # Use a running flag for graceful shutdown try: - # Check if queue has messages (non-blocking) - if not 
self.memos_message_queue.empty(): - # Get all available messages at once - messages = [] - while not self.memos_message_queue.empty(): - try: - messages.append(self.memos_message_queue.get_nowait()) - except queue.Empty: - break - - if messages: - try: - self.dispatcher.dispatch(messages) - except Exception as e: - logger.error(f"Error dispatching messages: {e!s}") - finally: - # Mark all messages as processed - for _ in messages: - self.memos_message_queue.task_done() + # Get all available messages at once (thread-safe approach) + messages = [] + while True: + try: + # Use get_nowait() directly without empty() check to avoid race conditions + message = self.memos_message_queue.get_nowait() + messages.append(message) + except queue.Empty: + # No more messages available + break + + if messages: + try: + self.dispatcher.dispatch(messages) + except Exception as e: + logger.error(f"Error dispatching messages: {e!s}") + finally: + # Mark all messages as processed + for _ in messages: + self.memos_message_queue.task_done() # Sleep briefly to prevent busy waiting time.sleep(self._consume_interval) # Adjust interval as needed diff --git a/src/memos/mem_scheduler/general_modules/misc.py b/src/memos/mem_scheduler/general_modules/misc.py index 41ebdfd4..3c7116b7 100644 --- a/src/memos/mem_scheduler/general_modules/misc.py +++ b/src/memos/mem_scheduler/general_modules/misc.py @@ -1,9 +1,10 @@ import json +import os from contextlib import suppress from datetime import datetime from queue import Empty, Full, Queue -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar from pydantic import field_serializer @@ -16,6 +17,75 @@ BaseModelType = TypeVar("T", bound="BaseModel") +class EnvConfigMixin(Generic[T]): + """Abstract base class for environment variable configuration.""" + + ENV_PREFIX = "MEMSCHEDULER_" + + @classmethod + def get_env_prefix(cls) -> str: + """Automatically generates environment variable prefix from class name. 
+ + Converts the class name to uppercase and appends an underscore. + If the class name ends with 'Config', that suffix is removed first. + + Examples: + RabbitMQConfig -> "RABBITMQ_" + OpenAIConfig -> "OPENAI_" + GraphDBAuthConfig -> "GRAPH_DB_AUTH_" + """ + class_name = cls.__name__ + # Remove 'Config' suffix if present + if class_name.endswith("Config"): + class_name = class_name[:-6] + # Convert to uppercase and add trailing underscore + + return f"{cls.ENV_PREFIX}{class_name.upper()}_" + + @classmethod + def from_env(cls: type[T]) -> T: + """Creates a config instance from environment variables. + + Reads all environment variables with the class-specific prefix and maps them + to corresponding configuration fields (converting to the appropriate types). + + Returns: + An instance of the config class populated from environment variables. + + Raises: + ValueError: If required environment variables are missing. + """ + prefix = cls.get_env_prefix() + field_values = {} + + for field_name, field_info in cls.model_fields.items(): + env_var = f"{prefix}{field_name.upper()}" + field_type = field_info.annotation + + if field_info.is_required() and env_var not in os.environ: + raise ValueError(f"Required environment variable {env_var} is missing") + + if env_var in os.environ: + raw_value = os.environ[env_var] + field_values[field_name] = cls._parse_env_value(raw_value, field_type) + elif field_info.default is not None: + field_values[field_name] = field_info.default + else: + raise ValueError() + return cls(**field_values) + + @classmethod + def _parse_env_value(cls, value: str, target_type: type) -> Any: + """Converts environment variable string to appropriate type.""" + if target_type is bool: + return value.lower() in ("true", "1", "t", "y", "yes") + if target_type is int: + return int(value) + if target_type is float: + return float(value) + return value + + class DictConversionMixin: """ Provides conversion functionality between Pydantic models and dictionaries, @@ 
-44,6 +114,26 @@ def to_dict(self) -> dict: dump_data["timestamp"] = self.serialize_datetime(self.timestamp, None) return dump_data + def to_json(self, **kwargs) -> str: + """ + Convert model instance to a JSON string. + - Accepts the same kwargs as json.dumps (e.g., indent, ensure_ascii) + - Default settings make JSON human-readable and UTF-8 safe + """ + return json.dumps(self.to_dict(), ensure_ascii=False, default=lambda o: str(o), **kwargs) + + @classmethod + def from_json(cls: type[BaseModelType], json_str: str) -> BaseModelType: + """ + Create model instance from a JSON string. + - Parses JSON into a dictionary and delegates to from_dict + """ + try: + data = json.loads(json_str) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON string: {e}") from e + return cls.from_dict(data) + @classmethod def from_dict(cls: type[BaseModelType], data: dict) -> BaseModelType: """ @@ -102,3 +192,11 @@ def put(self, item: T, block: bool = False, timeout: float | None = None) -> Non def get_queue_content_without_pop(self) -> list[T]: """Return a copy of the queue's contents without modifying it.""" return list(self.queue) + + def clear(self) -> None: + """Remove all items from the queue. + + This operation is thread-safe. 
+ """ + with self.mutex: + self.queue.clear() diff --git a/src/memos/mem_scheduler/general_modules/scheduler_logger.py b/src/memos/mem_scheduler/general_modules/scheduler_logger.py index 0aa66707..44e74453 100644 --- a/src/memos/mem_scheduler/general_modules/scheduler_logger.py +++ b/src/memos/mem_scheduler/general_modules/scheduler_logger.py @@ -69,7 +69,7 @@ def create_autofilled_log_item( and mem_cube_id in self.monitor.activation_memory_monitors[user_id] ): activation_monitor = self.monitor.activation_memory_monitors[user_id][mem_cube_id] - transformed_act_memory_size = len(activation_monitor.memories) + transformed_act_memory_size = len(activation_monitor.obj.memories) logger.info( f'activation_memory_monitors currently has "{transformed_act_memory_size}" transformed memory size' ) @@ -98,6 +98,7 @@ def create_autofilled_log_item( ) return log_message + # TODO: ๆ—ฅๅฟ—ๆ‰“ๅ‡บๆฅๆ•ฐ้‡ไธๅฏน @log_exceptions(logger=logger) def log_working_memory_replacement( self, @@ -125,6 +126,7 @@ def log_working_memory_replacement( added_memories = list(new_set - original_set) # Present in new but not original # recording messages + log_messages = [] for memory in added_memories: normalized_mem = transform_name_to_key(name=memory) if normalized_mem not in memory_type_map: @@ -145,11 +147,13 @@ def log_working_memory_replacement( mem_cube_id=mem_cube_id, mem_cube=mem_cube, ) - log_func_callback([log_message]) - logger.info( - f"{len(added_memories)} {LONG_TERM_MEMORY_TYPE} memorie(s) " - f"transformed to {WORKING_MEMORY_TYPE} memories." - ) + log_messages.append(log_message) + + logger.info( + f"{len(added_memories)} {LONG_TERM_MEMORY_TYPE} memorie(s) " + f"transformed to {WORKING_MEMORY_TYPE} memories." 
+ ) + log_func_callback(log_messages) @log_exceptions(logger=logger) def log_activation_memory_update( @@ -170,6 +174,7 @@ def log_activation_memory_update( added_memories = list(new_set - original_set) # Present in new but not original # recording messages + log_messages = [] for mem in added_memories: log_message_a = self.create_autofilled_log_item( log_content=mem, @@ -194,12 +199,13 @@ def log_activation_memory_update( mem_cube_id=mem_cube_id, mem_cube=mem_cube, ) - logger.info( - f"{len(added_memories)} {ACTIVATION_MEMORY_TYPE} memorie(s) " - f"transformed to {PARAMETER_MEMORY_TYPE} memories." - ) - log_func_callback([log_message_a, log_message_b]) + log_messages.extend([log_message_a, log_message_b]) + logger.info( + f"{len(added_memories)} {ACTIVATION_MEMORY_TYPE} memorie(s) " + f"transformed to {PARAMETER_MEMORY_TYPE} memories." + ) + log_func_callback(log_messages) @log_exceptions(logger=logger) def log_adding_memory( diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 08293886..2c914ff3 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -27,6 +27,8 @@ def __init__(self, config: GeneralSchedulerConfig): """Initialize the scheduler with the given configuration.""" super().__init__(config) + self.query_key_words_limit = self.config.get("query_key_words_limit", 20) + # register handlers handlers = { QUERY_LABEL: self._query_message_consumer, @@ -35,25 +37,43 @@ def __init__(self, config: GeneralSchedulerConfig): } self.dispatcher.register_handlers(handlers) - # for evaluation - def search_for_eval( - self, query: str, user_id: UserID | str, top_k: int, scheduler_flag: bool = True - ) -> (list[str], bool): + def update_working_memory_for_eval( + self, query: str, user_id: UserID | str, top_k: int + ) -> list[str]: + """ + Update working memory based on query and return the updated memory list. 
+ + Args: + query: The query string + user_id: User identifier + top_k: Number of top memories to return + + Returns: + List of memory strings from updated working memory + """ self.monitor.register_query_monitor_if_not_exists( user_id=user_id, mem_cube_id=self.current_mem_cube_id ) query_keywords = self.monitor.extract_query_keywords(query=query) - logger.info(f'Extract keywords "{query_keywords}" from query "{query}"') + logger.info( + f'Extracted keywords "{query_keywords}" from query "{query}" for user_id={user_id}' + ) item = QueryMonitorItem( + user_id=user_id, + mem_cube_id=self.current_mem_cube_id, query_text=query, keywords=query_keywords, max_keywords=DEFAULT_MAX_QUERY_KEY_WORDS, ) - query_monitor = self.monitor.query_monitors[user_id][self.current_mem_cube_id] - query_monitor.put(item=item) - logger.debug(f"Queries in monitor are {query_monitor.get_queries_with_timesort()}.") + query_db_manager = self.monitor.query_monitors[user_id][self.current_mem_cube_id] + query_db_manager.obj.put(item=item) + # Sync with database after adding new item + query_db_manager.sync_with_orm() + logger.debug( + f"Queries in monitor for user_id={user_id}, mem_cube_id={self.current_mem_cube_id}: {query_db_manager.obj.get_queries_with_timesort()}" + ) queries = [query] @@ -67,45 +87,95 @@ def search_for_eval( q_list=queries, text_working_memory=text_working_memory ) + if intent_result["trigger_retrieval"]: + missing_evidences = intent_result["missing_evidences"] + num_evidence = len(missing_evidences) + k_per_evidence = max(1, top_k // max(1, num_evidence)) + new_candidates = [] + for item in missing_evidences: + logger.info(f"Searching for missing evidence: '{item}' with top_k={k_per_evidence}") + results: list[TextualMemoryItem] = self.retriever.search( + query=item, + mem_cube=mem_cube, + top_k=k_per_evidence, + method=self.search_method, + ) + logger.info( + f"Search results for missing evidence '{item}': {[one.memory for one in results]}" + ) + 
new_candidates.extend(results) + print( + f"Missing evidences: {missing_evidences} -> Retrieved {len(new_candidates)} new candidate memories for user_id={user_id}" + ) + else: + new_candidates = [] + print( + f"Intent detection result: {intent_result} -> Retrieval not triggered for user_id={user_id}" + ) + + # rerank + new_order_working_memory = self.replace_working_memory( + user_id=user_id, + mem_cube_id=self.current_mem_cube_id, + mem_cube=self.current_mem_cube, + original_memory=cur_working_memory, + new_memory=new_candidates, + ) + new_order_working_memory = new_order_working_memory[:top_k] + logger.info( + f"Final working memory size: {len(new_order_working_memory)} memories for user_id={user_id}" + ) + + return [m.memory for m in new_order_working_memory] + + def evaluate_query_with_memories( + self, query: str, memory_texts: list[str], user_id: UserID | str + ) -> bool: + """ + Use LLM to evaluate whether the given memories can answer the query. + + Args: + query: The query string to evaluate + memory_texts: List of memory texts to check against + user_id: User identifier + + Returns: + Boolean indicating whether the memories can answer the query + """ + queries = [query] + intent_result = self.monitor.detect_intent(q_list=queries, text_working_memory=memory_texts) + return intent_result["trigger_retrieval"] + + # for evaluation + def search_for_eval( + self, query: str, user_id: UserID | str, top_k: int, scheduler_flag: bool = True + ) -> (list[str], bool): + """ + Original search_for_eval function refactored to use the new decomposed functions. 
+ """ if not scheduler_flag: - return text_working_memory, intent_result["trigger_retrieval"] + # Get current working memory without updating + mem_cube = self.current_mem_cube + text_mem_base = mem_cube.text_mem + cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory() + text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory] + + # Use the evaluation function to check if memories can answer the query + can_answer = self.evaluate_query_with_memories( + query=query, memory_texts=text_working_memory, user_id=user_id + ) + return text_working_memory, can_answer else: - if intent_result["trigger_retrieval"]: - missing_evidences = intent_result["missing_evidences"] - num_evidence = len(missing_evidences) - k_per_evidence = max(1, top_k // max(1, num_evidence)) - new_candidates = [] - for item in missing_evidences: - logger.info(f"missing_evidences: {item}") - results: list[TextualMemoryItem] = self.retriever.search( - query=item, - mem_cube=mem_cube, - top_k=k_per_evidence, - method=self.search_method, - ) - logger.info( - f"search results for {missing_evidences}: {[one.memory for one in results]}" - ) - new_candidates.extend(results) - print( - f"missing_evidences: {missing_evidences} and get {len(new_candidates)} new candidate memories." - ) - else: - new_candidates = [] - print(f"intent_result: {intent_result}. 
not triggered") - - # rerank - new_order_working_memory = self.replace_working_memory( - user_id=user_id, - mem_cube_id=self.current_mem_cube_id, - mem_cube=self.current_mem_cube, - original_memory=cur_working_memory, - new_memory=new_candidates, + # Update working memory and get the result + updated_memories = self.update_working_memory_for_eval( + query=query, user_id=user_id, top_k=top_k ) - new_order_working_memory = new_order_working_memory[:top_k] - logger.info(f"size of new_order_working_memory: {len(new_order_working_memory)}") - return [m.memory for m in new_order_working_memory], intent_result["trigger_retrieval"] + # Use the evaluation function to check if memories can answer the query + can_answer = self.evaluate_query_with_memories( + query=query, memory_texts=updated_memories, user_id=user_id + ) + return updated_memories, can_answer def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: """ @@ -140,7 +210,9 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: query = msg.content query_keywords = self.monitor.extract_query_keywords(query=query) - logger.info(f'Extract keywords "{query_keywords}" from query "{query}"') + logger.info( + f'Extracted keywords "{query_keywords}" from query "{query}" for user_id={user_id}' + ) if len(query_keywords) == 0: stripped_query = query.strip() @@ -155,21 +227,26 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: ) words = stripped_query # Default to character count - query_keywords = list(set(words[:20])) + query_keywords = list(set(words[: self.query_key_words_limit])) logger.error( - f"Keyword extraction failed for query. Using fallback keywords: {query_keywords[:10]}... (truncated)" + f"Keyword extraction failed for query '{query}' (user_id={user_id}). Using fallback keywords: {query_keywords[:10]}... 
(truncated)", + exc_info=True, ) item = QueryMonitorItem( + user_id=user_id, + mem_cube_id=mem_cube_id, query_text=query, keywords=query_keywords, max_keywords=DEFAULT_MAX_QUERY_KEY_WORDS, ) - self.monitor.query_monitors[user_id][mem_cube_id].put(item=item) + query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id] + query_db_manager.obj.put(item=item) + # Sync with database after adding new item + query_db_manager.sync_with_orm() logger.debug( - f"Queries in monitor are " - f"{self.monitor.query_monitors[user_id][mem_cube_id].get_queries_with_timesort()}." + f"Queries in monitor for user_id={user_id}, mem_cube_id={mem_cube_id}: {query_db_manager.obj.get_queries_with_timesort()}" ) queries = [msg.content for msg in messages] @@ -183,7 +260,7 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: top_k=self.top_k, ) logger.info( - f"Processed {queries} and get {len(new_candidates)} new candidate memories." + f"Processed {len(queries)} queries {queries} and retrieved {len(new_candidates)} new candidate memories for user_id={user_id}" ) # rerank @@ -194,7 +271,9 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: original_memory=cur_working_memory, new_memory=new_candidates, ) - logger.info(f"size of new_order_working_memory: {len(new_order_working_memory)}") + logger.info( + f"Final working memory size: {len(new_order_working_memory)} memories for user_id={user_id}" + ) # update activation memories logger.info( @@ -293,10 +372,17 @@ def process_session_turn( text_mem_base = mem_cube.text_mem if not isinstance(text_mem_base, TreeTextMemory): - logger.error("Not implemented!", exc_info=True) + logger.error( + f"Not implemented! Expected TreeTextMemory but got {type(text_mem_base).__name__} " + f"for mem_cube_id={mem_cube_id}, user_id={user_id}. 
" + f"text_mem_base value: {text_mem_base}", + exc_info=True, + ) return - logger.info(f"Processing {len(queries)} queries.") + logger.info( + f"Processing {len(queries)} queries for user_id={user_id}, mem_cube_id={mem_cube_id}" + ) cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory() text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory] @@ -312,16 +398,20 @@ def process_session_turn( time_trigger_flag = True if (not intent_result["trigger_retrieval"]) and (not time_trigger_flag): - logger.info(f"Query schedule not triggered. Intent_result: {intent_result}") + logger.info( + f"Query schedule not triggered for user_id={user_id}, mem_cube_id={mem_cube_id}. Intent_result: {intent_result}" + ) return elif (not intent_result["trigger_retrieval"]) and time_trigger_flag: - logger.info("Query schedule is forced to trigger due to time ticker") + logger.info( + f"Query schedule forced to trigger due to time ticker for user_id={user_id}, mem_cube_id={mem_cube_id}" + ) intent_result["trigger_retrieval"] = True intent_result["missing_evidences"] = queries else: logger.info( - f'Query schedule triggered for user "{user_id}" and mem_cube "{mem_cube_id}".' - f" Missing evidences: {intent_result['missing_evidences']}" + f"Query schedule triggered for user_id={user_id}, mem_cube_id={mem_cube_id}. 
" + f"Missing evidences: {intent_result['missing_evidences']}" ) missing_evidences = intent_result["missing_evidences"] @@ -329,7 +419,9 @@ def process_session_turn( k_per_evidence = max(1, top_k // max(1, num_evidence)) new_candidates = [] for item in missing_evidences: - logger.info(f"missing_evidences: {item}") + logger.info( + f"Searching for missing evidence: '{item}' with top_k={k_per_evidence} for user_id={user_id}" + ) info = { "user_id": user_id, "session_id": "", @@ -343,7 +435,7 @@ def process_session_turn( info=info, ) logger.info( - f"search results for {missing_evidences}: {[one.memory for one in results]}" + f"Search results for missing evidence '{item}': {[one.memory for one in results]}" ) new_candidates.extend(results) return cur_working_memory, new_candidates diff --git a/src/memos/mem_scheduler/memory_manage_modules/__init__.py b/src/memos/mem_scheduler/memory_manage_modules/__init__.py new file mode 100644 index 00000000..94d70429 --- /dev/null +++ b/src/memos/mem_scheduler/memory_manage_modules/__init__.py @@ -0,0 +1,5 @@ +from .memory_filter import MemoryFilter +from .retriever import SchedulerRetriever + + +__all__ = ["MemoryFilter", "SchedulerRetriever"] diff --git a/src/memos/mem_scheduler/memory_manage_modules/memory_filter.py b/src/memos/mem_scheduler/memory_manage_modules/memory_filter.py new file mode 100644 index 00000000..e18c6e51 --- /dev/null +++ b/src/memos/mem_scheduler/memory_manage_modules/memory_filter.py @@ -0,0 +1,308 @@ +from memos.configs.mem_scheduler import BaseSchedulerConfig +from memos.llms.base import BaseLLM +from memos.log import get_logger +from memos.mem_scheduler.general_modules.base import BaseSchedulerModule +from memos.mem_scheduler.utils.misc_utils import extract_json_dict +from memos.memories.textual.tree import TextualMemoryItem + + +logger = get_logger(__name__) + + +class MemoryFilter(BaseSchedulerModule): + def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig): + super().__init__() + 
self.config: BaseSchedulerConfig = config + self.process_llm = process_llm + + def filter_unrelated_memories( + self, + query_history: list[str], + memories: list[TextualMemoryItem], + ) -> (list[TextualMemoryItem], bool): + """ + Filter out memories that are completely unrelated to the query history using LLM. + + Args: + query_history: List of query strings to determine relevance + memories: List of TextualMemoryItem objects to be filtered + + Returns: + Tuple of (filtered_memories, success_flag) + - filtered_memories: List of TextualMemoryItem objects that are relevant to queries + - success_flag: Boolean indicating if LLM filtering was successful + + Note: + If LLM filtering fails, returns all memories (conservative approach) + """ + success_flag = False + + if not memories: + logger.info("No memories to filter - returning empty list") + return [], True + + if not query_history: + logger.info("No query history provided - keeping all memories") + return memories, True + + logger.info( + f"Starting memory filtering for {len(memories)} memories against {len(query_history)} queries" + ) + + # Extract memory texts for LLM processing + memory_texts = [mem.memory for mem in memories] + + # Build LLM prompt for memory filtering + prompt = self.build_prompt( + "memory_filtering", + query_history=[f"[{i}] {query}" for i, query in enumerate(query_history)], + memories=[f"[{i}] {mem}" for i, mem in enumerate(memory_texts)], + ) + logger.debug(f"Generated filtering prompt: {prompt[:200]}...") # Log first 200 chars + + # Get LLM response + response = self.process_llm.generate([{"role": "user", "content": prompt}]) + logger.debug(f"Received LLM filtering response: {response[:200]}...") # Log first 200 chars + + try: + # Parse JSON response + response = extract_json_dict(response) + logger.debug(f"Parsed JSON response: {response}") + relevant_indices = response["relevant_memories"] + filtered_count = response["filtered_count"] + reasoning = response["reasoning"] + + # Validate 
indices + if not isinstance(relevant_indices, list): + raise ValueError("relevant_memories must be a list") + + # Filter memories based on relevant indices + filtered_memories = [] + for idx in relevant_indices: + if isinstance(idx, int) and 0 <= idx < len(memories): + filtered_memories.append(memories[idx]) + else: + logger.warning(f"Invalid memory index {idx} - skipping") + + logger.info( + f"Successfully filtered memories. Kept {len(filtered_memories)} out of {len(memories)} memories. " + f"Filtered out {filtered_count} unrelated memories. " + f"Filtering reasoning: {reasoning}" + ) + success_flag = True + + except Exception as e: + logger.error( + f"Failed to filter memories with LLM. Exception: {e}. Raw response: {response}", + exc_info=True, + ) + # Conservative approach: keep all memories if filtering fails + filtered_memories = memories + success_flag = False + + return filtered_memories, success_flag + + def filter_redundant_memories( + self, + query_history: list[str], + memories: list[TextualMemoryItem], + ) -> (list[TextualMemoryItem], bool): + """ + Filter out redundant memories using LLM analysis. + + This function removes redundant memories by keeping the most informative + version when multiple memories contain similar information relevant to queries. 
+ + Args: + query_history: List of query strings to determine relevance and value + memories: List of TextualMemoryItem objects to be filtered + + Returns: + Tuple of (filtered_memories, success_flag) + - filtered_memories: List of TextualMemoryItem objects after redundancy filtering + - success_flag: Boolean indicating if LLM filtering was successful + + Note: + If LLM filtering fails, returns all memories (conservative approach) + """ + success_flag = False + + if not memories: + logger.info("No memories to filter for redundancy - returning empty list") + return [], True + + if not query_history: + logger.info("No query history provided - keeping all memories") + return memories, True + + if len(memories) <= 1: + logger.info("Only one memory - no redundancy to filter") + return memories, True + + logger.info( + f"Starting redundancy filtering for {len(memories)} memories against {len(query_history)} queries" + ) + + # Extract memory texts for LLM processing + memory_texts = [mem.memory for mem in memories] + + # Build LLM prompt for redundancy filtering + prompt = self.build_prompt( + "memory_redundancy_filtering", + query_history=[f"[{i}] {query}" for i, query in enumerate(query_history)], + memories=[f"[{i}] {mem}" for i, mem in enumerate(memory_texts)], + ) + logger.debug( + f"Generated redundancy filtering prompt: {prompt[:200]}..." + ) # Log first 200 chars + + # Get LLM response + response = self.process_llm.generate([{"role": "user", "content": prompt}]) + logger.debug( + f"Received LLM redundancy filtering response: {response[:200]}..." 
+ ) # Log first 200 chars + + try: + # Parse JSON response + response = extract_json_dict(response) + logger.debug(f"Parsed JSON response: {response}") + kept_indices = response["kept_memories"] + redundant_groups = response.get("redundant_groups", []) + reasoning = response["reasoning"] + + # Validate indices + if not isinstance(kept_indices, list): + raise ValueError("kept_memories must be a list") + + # Filter memories based on kept indices + filtered_memories = [] + for idx in kept_indices: + if isinstance(idx, int) and 0 <= idx < len(memories): + filtered_memories.append(memories[idx]) + else: + logger.warning(f"Invalid memory index {idx} - skipping") + + logger.info( + f"Successfully filtered redundant memories. " + f"Kept {len(filtered_memories)} out of {len(memories)} memories. " + f"Removed {len(memories) - len(filtered_memories)} redundant memories. " + f"Redundant groups identified: {len(redundant_groups)}. " + f"Filtering reasoning: {reasoning}" + ) + success_flag = True + + except Exception as e: + logger.error( + f"Failed to filter redundant memories with LLM. Exception: {e}. Raw response: {response}", + exc_info=True, + ) + # Conservative approach: keep all memories if filtering fails + filtered_memories = memories + success_flag = False + + return filtered_memories, success_flag + + def filter_unrelated_and_redundant_memories( + self, + query_history: list[str], + memories: list[TextualMemoryItem], + ) -> (list[TextualMemoryItem], bool): + """ + Filter out both unrelated and redundant memories using LLM analysis. + + This function performs two types of filtering in sequence: + 1. Remove memories that are completely unrelated to the query history + 2. 
Remove redundant memories by keeping the most informative version + + Args: + query_history: List of query strings to determine relevance and value + memories: List of TextualMemoryItem objects to be filtered + + Returns: + Tuple of (filtered_memories, success_flag) + - filtered_memories: List of TextualMemoryItem objects after both filtering steps + - success_flag: Boolean indicating if LLM filtering was successful + + Note: + If LLM filtering fails, returns all memories (conservative approach) + """ + success_flag = False + + if not memories: + logger.info("No memories to filter for unrelated and redundant - returning empty list") + return [], True + + if not query_history: + logger.info("No query history provided - keeping all memories") + return memories, True + + if len(memories) <= 1: + logger.info("Only one memory - no filtering needed") + return memories, True + + logger.info( + f"Starting combined unrelated and redundant filtering for {len(memories)} memories against {len(query_history)} queries" + ) + + # Extract memory texts for LLM processing + memory_texts = [mem.memory for mem in memories] + + # Build LLM prompt for combined filtering + prompt = self.build_prompt( + "memory_combined_filtering", + query_history=[f"[{i}] {query}" for i, query in enumerate(query_history)], + memories=[f"[{i}] {mem}" for i, mem in enumerate(memory_texts)], + ) + logger.debug( + f"Generated combined filtering prompt: {prompt[:200]}..." + ) # Log first 200 chars + + # Get LLM response + response = self.process_llm.generate([{"role": "user", "content": prompt}]) + logger.debug( + f"Received LLM combined filtering response: {response[:200]}..." 
+ ) # Log first 200 chars + + try: + # Parse JSON response + response = extract_json_dict(response) + logger.debug(f"Parsed JSON response: {response}") + kept_indices = response["kept_memories"] + unrelated_removed_count = response.get("unrelated_removed_count", 0) + redundant_removed_count = response.get("redundant_removed_count", 0) + redundant_groups = response.get("redundant_groups", []) + reasoning = response["reasoning"] + + # Validate indices + if not isinstance(kept_indices, list): + raise ValueError("kept_memories must be a list") + + # Filter memories based on kept indices + filtered_memories = [] + for idx in kept_indices: + if isinstance(idx, int) and 0 <= idx < len(memories): + filtered_memories.append(memories[idx]) + else: + logger.warning(f"Invalid memory index {idx} - skipping") + + logger.info( + f"Successfully filtered unrelated and redundant memories. " + f"Kept {len(filtered_memories)} out of {len(memories)} memories. " + f"Removed {len(memories) - len(filtered_memories)} memories total. " + f"Unrelated removed: {unrelated_removed_count}. " + f"Redundant removed: {redundant_removed_count}. " + f"Redundant groups identified: {len(redundant_groups)}. " + f"Filtering reasoning: {reasoning}" + ) + success_flag = True + + except Exception as e: + logger.error( + f"Failed to filter unrelated and redundant memories with LLM. Exception: {e}. 
Raw response: {response}", + exc_info=True, + ) + # Conservative approach: keep all memories if filtering fails + filtered_memories = memories + success_flag = False + + return filtered_memories, success_flag diff --git a/src/memos/mem_scheduler/general_modules/retriever.py b/src/memos/mem_scheduler/memory_manage_modules/retriever.py similarity index 86% rename from src/memos/mem_scheduler/general_modules/retriever.py rename to src/memos/mem_scheduler/memory_manage_modules/retriever.py index 3732078d..c650e952 100644 --- a/src/memos/mem_scheduler/general_modules/retriever.py +++ b/src/memos/mem_scheduler/memory_manage_modules/retriever.py @@ -8,8 +8,8 @@ TreeTextMemory_SEARCH_METHOD, ) from memos.mem_scheduler.utils.filter_utils import ( - filter_similar_memories, filter_too_short_memories, + filter_vector_based_similar_memories, transform_name_to_key, ) from memos.mem_scheduler.utils.misc_utils import ( @@ -17,6 +17,8 @@ ) from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory +from .memory_filter import MemoryFilter + logger = get_logger(__name__) @@ -32,6 +34,9 @@ def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig): self.config: BaseSchedulerConfig = config self.process_llm = process_llm + # Initialize memory filter + self.memory_filter = MemoryFilter(process_llm=process_llm, config=config) + def search( self, query: str, @@ -163,7 +168,7 @@ def process_and_rerank_memories( combined_text_memory = [m.memory for m in combined_memory] # Apply similarity filter to remove overly similar memories - filtered_combined_text_memory = filter_similar_memories( + filtered_combined_text_memory = filter_vector_based_similar_memories( text_memories=combined_text_memory, similarity_threshold=self.filter_similarity_threshold, ) @@ -197,3 +202,29 @@ def process_and_rerank_memories( ) return memories_with_new_order, success_flag + + def filter_unrelated_memories( + self, + query_history: list[str], + memories: list[TextualMemoryItem], + ) -> 
(list[TextualMemoryItem], bool): + return self.memory_filter.filter_unrelated_memories(query_history, memories) + + def filter_redundant_memories( + self, + query_history: list[str], + memories: list[TextualMemoryItem], + ) -> (list[TextualMemoryItem], bool): + return self.memory_filter.filter_redundant_memories(query_history, memories) + + def filter_unrelated_and_redundant_memories( + self, + query_history: list[str], + memories: list[TextualMemoryItem], + ) -> (list[TextualMemoryItem], bool): + """ + Filter out both unrelated and redundant memories using LLM analysis. + + This method delegates to the MemoryFilter class. + """ + return self.memory_filter.filter_unrelated_and_redundant_memories(query_history, memories) diff --git a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py index 229e9c3a..7fabe15a 100644 --- a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py +++ b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py @@ -21,7 +21,7 @@ def __init__(self, config: BaseSchedulerConfig): super().__init__() self.config: BaseSchedulerConfig = config - self.check_interval = self.config.get("dispatcher_monitor_check_interval", 60) + self.check_interval = self.config.get("dispatcher_monitor_check_interval", 300) self.max_failures = self.config.get("dispatcher_monitor_max_failures", 2) # Registry of monitored thread pools @@ -177,10 +177,11 @@ def _check_pools_health(self) -> None: else: pool_info["failure_count"] += 1 pool_info["healthy"] = False - logger.warning( - f"Pool '{name}' unhealthy ({pool_info['failure_count']}/{self.max_failures}): {reason}" + logger.info( + f"Pool '{name}' unhealthy ({pool_info['failure_count']}/{self.max_failures}): {reason}." 
+ f" Note: This status does not necessarily indicate a problem with the pool itself - " + f"it may also be considered unhealthy if no tasks have been scheduled for an extended period" ) - if ( pool_info["failure_count"] >= self.max_failures and pool_info["restart"] @@ -236,7 +237,7 @@ def _restart_pool(self, name: str, pool_info: dict) -> None: return self._restart_in_progress = True - logger.warning(f"Attempting to restart thread pool '{name}'") + logger.info(f"Attempting to restart thread pool '{name}'") try: old_executor = pool_info["executor"] diff --git a/src/memos/mem_scheduler/monitors/general_monitor.py b/src/memos/mem_scheduler/monitors/general_monitor.py index 6bc796cc..87d99654 100644 --- a/src/memos/mem_scheduler/monitors/general_monitor.py +++ b/src/memos/mem_scheduler/monitors/general_monitor.py @@ -2,11 +2,18 @@ from threading import Lock from typing import Any +from sqlalchemy.engine import Engine + from memos.configs.mem_scheduler import BaseSchedulerConfig from memos.llms.base import BaseLLM from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube from memos.mem_scheduler.general_modules.base import BaseSchedulerModule +from memos.mem_scheduler.orm_modules.base_model import BaseDBManager +from memos.mem_scheduler.orm_modules.monitor_models import ( + DBManagerForMemoryMonitorManager, + DBManagerForQueryMonitorQueue, +) from memos.mem_scheduler.schemas.general_schemas import ( DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT, DEFAULT_WEIGHT_VECTOR_FOR_RANKING, @@ -19,7 +26,6 @@ from memos.mem_scheduler.schemas.monitor_schemas import ( MemoryMonitorItem, MemoryMonitorManager, - QueryMonitorItem, QueryMonitorQueue, ) from memos.mem_scheduler.utils.misc_utils import extract_json_dict @@ -32,7 +38,9 @@ class SchedulerGeneralMonitor(BaseSchedulerModule): """Monitors and manages scheduling operations with LLM integration.""" - def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig): + def __init__( + self, process_llm: 
BaseLLM, config: BaseSchedulerConfig, db_engine: Engine | None = None + ): super().__init__() # hyper-parameters @@ -49,12 +57,22 @@ def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig): "activation_mem_monitor_capacity", DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT ) - # attributes - # recording query_messages - self.query_monitors: dict[UserID, dict[MemCubeID, QueryMonitorQueue[QueryMonitorItem]]] = {} + # ORM-based monitor managers + self.db_engine = db_engine + if self.db_engine is None: + logger.warning( + "No database engine provided; falling back to default temporary SQLite engine. " + "This is intended for testing only. Consider providing a configured engine for production use." + ) + self.db_engine = BaseDBManager.create_default_engine() - self.working_memory_monitors: dict[UserID, dict[MemCubeID, MemoryMonitorManager]] = {} - self.activation_memory_monitors: dict[UserID, dict[MemCubeID, MemoryMonitorManager]] = {} + self.query_monitors: dict[UserID, dict[MemCubeID, DBManagerForQueryMonitorQueue]] = {} + self.working_memory_monitors: dict[ + UserID, dict[MemCubeID, DBManagerForMemoryMonitorManager] + ] = {} + self.activation_memory_monitors: dict[ + UserID, dict[MemCubeID, DBManagerForMemoryMonitorManager] + ] = {} # Lifecycle monitor self.last_activation_mem_update_time = datetime.min @@ -96,40 +114,47 @@ def register_query_monitor_if_not_exists( if user_id not in self.query_monitors: self.query_monitors[user_id] = {} if mem_cube_id not in self.query_monitors[user_id]: - self.query_monitors[user_id][mem_cube_id] = QueryMonitorQueue( - maxsize=self.config.context_window_size - ) + if self.db_engine: + # Create ORM manager with initial QueryMonitorQueue + initial_queue = QueryMonitorQueue(maxsize=self.config.context_window_size) + db_manager = DBManagerForQueryMonitorQueue( + engine=self.db_engine, + user_id=str(user_id), + mem_cube_id=str(mem_cube_id), + obj=initial_queue, + ) + self.query_monitors[user_id][mem_cube_id] = db_manager + else: + 
# Fallback to in-memory (this shouldn't happen with proper config) + logger.warning("ORM persistence disabled, using in-memory fallback") + # For backward compatibility, we'll need to handle this case differently + raise RuntimeError("ORM persistence is required but not properly configured") def register_memory_manager_if_not_exists( self, user_id: UserID | str, mem_cube_id: MemCubeID | str, - memory_monitors: dict[UserID, dict[MemCubeID, MemoryMonitorManager]], + memory_monitors: dict[UserID, dict[MemCubeID, DBManagerForMemoryMonitorManager]], max_capacity: int, ) -> None: """ - Register a new MemoryMonitorManager for the given user and memory cube if it doesn't exist. + Register a new MemoryMonitorManager ORM manager for the given user and memory cube if it doesn't exist. Thread-safe implementation using double-checked locking pattern. - Checks if a MemoryMonitorManager already exists for the specified user_id and mem_cube_id. - If not, creates a new MemoryMonitorManager with appropriate capacity settings and registers it. + Checks if a MemoryMonitorManager ORM manager already exists for the specified user_id and mem_cube_id. + If not, creates a new ORM manager with appropriate capacity settings and registers it. Args: user_id: The ID of the user to associate with the memory manager mem_cube_id: The ID of the memory cube to monitor - memory_monitors: Dictionary storing existing memory monitor managers + memory_monitors: Dictionary storing existing memory monitor ORM managers max_capacity: Maximum capacity for the new memory monitor manager - lock: Threading lock to ensure safe concurrent access - - Note: - This function will update the loose_max_working_memory_capacity based on the current - WorkingMemory size plus partial retention number before creating a new manager. 
""" # First check (lock-free, fast path) # Quickly verify existence without lock overhead if user_id in memory_monitors and mem_cube_id in memory_monitors[user_id]: logger.info( - f"MemoryMonitorManager already exists for user_id={user_id}, " + f"MemoryMonitorManager ORM manager already exists for user_id={user_id}, " f"mem_cube_id={mem_cube_id} in the provided memory_monitors dictionary" ) return @@ -140,22 +165,33 @@ def register_memory_manager_if_not_exists( # Re-check after acquiring lock, as another thread might have created it if user_id in memory_monitors and mem_cube_id in memory_monitors[user_id]: logger.info( - f"MemoryMonitorManager already exists for user_id={user_id}, " + f"MemoryMonitorManager ORM manager already exists for user_id={user_id}, " f"mem_cube_id={mem_cube_id} in the provided memory_monitors dictionary" ) return - # Initialize MemoryMonitorManager with user ID, memory cube ID, and max capacity - monitor_manager = MemoryMonitorManager( - user_id=user_id, mem_cube_id=mem_cube_id, max_capacity=max_capacity - ) + if self.db_engine: + # Initialize MemoryMonitorManager with user ID, memory cube ID, and max capacity + monitor_manager = MemoryMonitorManager( + user_id=user_id, mem_cube_id=mem_cube_id, max_capacity=max_capacity + ) - # Safely register the new manager in the nested dictionary structure - memory_monitors.setdefault(user_id, {})[mem_cube_id] = monitor_manager - logger.info( - f"Registered new MemoryMonitorManager for user_id={user_id}," - f" mem_cube_id={mem_cube_id} with max_capacity={max_capacity}" - ) + # Create ORM manager + db_manager = DBManagerForMemoryMonitorManager( + engine=self.db_engine, + user_id=str(user_id), + mem_cube_id=str(mem_cube_id), + obj=monitor_manager, + ) + + # Safely register the new ORM manager in the nested dictionary structure + memory_monitors.setdefault(user_id, {})[mem_cube_id] = db_manager + logger.info( + f"Registered new MemoryMonitorManager ORM manager for user_id={user_id}," + f" 
mem_cube_id={mem_cube_id} with max_capacity={max_capacity}" + ) + else: + raise RuntimeError("ORM persistence is required but not properly configured") def update_working_memory_monitors( self, @@ -182,10 +218,14 @@ def update_working_memory_monitors( max_capacity=self.working_mem_monitor_capacity, ) - self.working_memory_monitors[user_id][mem_cube_id].update_memories( + # Get the ORM manager and update memories with database sync + db_manager = self.working_memory_monitors[user_id][mem_cube_id] + db_manager.obj.update_memories( new_memory_monitors=new_working_memory_monitors, partial_retention_number=self.partial_retention_number, ) + # Sync with database + db_manager.sync_with_orm(size_limit=self.working_mem_monitor_capacity) def update_activation_memory_monitors( self, user_id: str, mem_cube_id: str, mem_cube: GeneralMemCube @@ -199,17 +239,21 @@ def update_activation_memory_monitors( # === update activation memory monitors === # Sort by importance_score in descending order and take top k + working_db_manager = self.working_memory_monitors[user_id][mem_cube_id] top_k_memories = sorted( - self.working_memory_monitors[user_id][mem_cube_id].memories, + working_db_manager.obj.memories, key=lambda m: m.get_importance_score(weight_vector=DEFAULT_WEIGHT_VECTOR_FOR_RANKING), reverse=True, )[: self.activation_mem_monitor_capacity] # Update the activation memory monitors with these important memories - self.activation_memory_monitors[user_id][mem_cube_id].update_memories( + activation_db_manager = self.activation_memory_monitors[user_id][mem_cube_id] + activation_db_manager.obj.update_memories( new_memory_monitors=top_k_memories, partial_retention_number=self.partial_retention_number, ) + # Sync with database + activation_db_manager.sync_with_orm(size_limit=self.activation_mem_monitor_capacity) def timed_trigger(self, last_time: datetime, interval_seconds: float) -> bool: now = datetime.utcnow() @@ -255,9 +299,12 @@ def get_monitor_memories( ) return [] - manager: 
MemoryMonitorManager = monitor_dict[user_id][mem_cube_id] + db_manager: DBManagerForMemoryMonitorManager = monitor_dict[user_id][mem_cube_id] + # Load latest data from database before accessing + db_manager.sync_with_orm() + # Sort memories by recording_count in descending order and return top_k items - sorted_memory_monitors = manager.get_sorted_mem_monitors(reverse=True) + sorted_memory_monitors = db_manager.obj.get_sorted_mem_monitors(reverse=True) sorted_text_memories = [m.memory_text for m in sorted_memory_monitors[:top_k]] return sorted_text_memories @@ -273,16 +320,19 @@ def get_monitors_info(self, user_id: str, mem_cube_id: str) -> dict[str, Any]: return {} info_dict = {} - for manager in [ + for db_manager in [ self.working_memory_monitors[user_id][mem_cube_id], self.activation_memory_monitors[user_id][mem_cube_id], ]: + # Sync with database to get latest data + db_manager.sync_with_orm() + manager = db_manager.obj info_dict[str(type(manager))] = { "user_id": user_id, "mem_cube_id": mem_cube_id, "memory_count": manager.memory_size, "max_capacity": manager.max_capacity, - "top_memories": self.get_scheduler_working_memories(user_id, mem_cube_id, top_k=1), + "top_memories": self.get_monitor_memories(user_id, mem_cube_id, top_k=1), } return info_dict @@ -308,3 +358,33 @@ def detect_intent( logger.error(f"Fail to extract json dict from response: {response}") response = {"trigger_retrieval": False, "missing_evidences": q_list} return response + + def close(self): + """Close all database connections and clean up resources""" + logger.info("Closing database connections for all monitors") + + # Close all query monitor database managers + for user_monitors in self.query_monitors.values(): + for db_manager in user_monitors.values(): + try: + db_manager.close() + except Exception as e: + logger.error(f"Error closing query monitor DB manager: {e}") + + # Close all working memory monitor database managers + for user_monitors in self.working_memory_monitors.values(): + 
for db_manager in user_monitors.values(): + try: + db_manager.close() + except Exception as e: + logger.error(f"Error closing working memory monitor DB manager: {e}") + + # Close all activation memory monitor database managers + for user_monitors in self.activation_memory_monitors.values(): + for db_manager in user_monitors.values(): + try: + db_manager.close() + except Exception as e: + logger.error(f"Error closing activation memory monitor DB manager: {e}") + + logger.info("All database connections closed") diff --git a/src/memos/mem_scheduler/mos_for_test_scheduler.py b/src/memos/mem_scheduler/mos_for_test_scheduler.py deleted file mode 100644 index f275da2b..00000000 --- a/src/memos/mem_scheduler/mos_for_test_scheduler.py +++ /dev/null @@ -1,146 +0,0 @@ -from datetime import datetime - -from memos.configs.mem_os import MOSConfig -from memos.log import get_logger -from memos.mem_os.main import MOS -from memos.mem_scheduler.schemas.general_schemas import ( - ANSWER_LABEL, - MONITOR_WORKING_MEMORY_TYPE, - QUERY_LABEL, -) -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem - - -logger = get_logger(__name__) - - -class MOSForTestScheduler(MOS): - """This class is only to test abilities of mem scheduler""" - - def __init__(self, config: MOSConfig): - super().__init__(config) - - def _str_memories(self, memories: list[str]) -> str: - """Format memories for display.""" - if not memories: - return "No memories." - return "\n".join(f"{i + 1}. {memory}" for i, memory in enumerate(memories)) - - def chat(self, query: str, user_id: str | None = None) -> str: - """ - Chat with the MOS. - - Args: - query (str): The user's query. - - Returns: - str: The response from the MOS. 
- """ - target_user_id = user_id if user_id is not None else self.user_id - accessible_cubes = self.user_manager.get_user_cubes(target_user_id) - user_cube_ids = [cube.cube_id for cube in accessible_cubes] - if target_user_id not in self.chat_history_manager: - self._register_chat_history(target_user_id) - - chat_history = self.chat_history_manager[target_user_id] - - topk_for_scheduler = 2 - - if self.config.enable_textual_memory and self.mem_cubes: - memories_all = [] - for mem_cube_id, mem_cube in self.mem_cubes.items(): - if mem_cube_id not in user_cube_ids: - continue - if not mem_cube.text_mem: - continue - - message_item = ScheduleMessageItem( - user_id=target_user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - label=QUERY_LABEL, - content=query, - timestamp=datetime.now(), - ) - cur_working_memories = [m.memory for m in mem_cube.text_mem.get_working_memory()] - print(f"Working memories before schedule: {cur_working_memories}") - - # --- force to run mem_scheduler --- - self.mem_scheduler.monitor.query_trigger_interval = 0 - self.mem_scheduler._query_message_consumer(messages=[message_item]) - - # from scheduler - scheduler_memories = self.mem_scheduler.monitor.get_monitor_memories( - user_id=target_user_id, - mem_cube_id=mem_cube_id, - memory_type=MONITOR_WORKING_MEMORY_TYPE, - top_k=topk_for_scheduler, - ) - print(f"Working memories after schedule: {scheduler_memories}") - memories_all.extend(scheduler_memories) - - # from mem_cube - memories = mem_cube.text_mem.search( - query, - top_k=self.config.top_k - topk_for_scheduler, - info={ - "user_id": target_user_id, - "session_id": self.session_id, - "chat_history": chat_history.chat_history, - }, - ) - text_memories = [m.memory for m in memories] - print(f"Search results with new working memories: {text_memories}") - memories_all.extend(text_memories) - - memories_all = list(set(memories_all)) - - logger.info(f"๐Ÿง  [Memory] Searched memories:\n{self._str_memories(memories_all)}\n") - system_prompt = 
self._build_system_prompt(memories_all) - else: - system_prompt = self._build_system_prompt() - current_messages = [ - {"role": "system", "content": system_prompt}, - *chat_history.chat_history, - {"role": "user", "content": query}, - ] - past_key_values = None - - if self.config.enable_activation_memory: - assert self.config.chat_model.backend == "huggingface", ( - "Activation memory only used for huggingface backend." - ) - # TODO this only one cubes - for mem_cube_id, mem_cube in self.mem_cubes.items(): - if mem_cube_id not in user_cube_ids: - continue - if mem_cube.act_mem: - kv_cache = next(iter(mem_cube.act_mem.get_all()), None) - past_key_values = ( - kv_cache.memory if (kv_cache and hasattr(kv_cache, "memory")) else None - ) - break - # Generate response - response = self.chat_llm.generate(current_messages, past_key_values=past_key_values) - else: - response = self.chat_llm.generate(current_messages) - logger.info(f"๐Ÿค– [Assistant] {response}\n") - chat_history.chat_history.append({"role": "user", "content": query}) - chat_history.chat_history.append({"role": "assistant", "content": response}) - self.chat_history_manager[user_id] = chat_history - - # submit message to scheduler - for accessible_mem_cube in accessible_cubes: - mem_cube_id = accessible_mem_cube.cube_id - mem_cube = self.mem_cubes[mem_cube_id] - if self.enable_mem_scheduler and self.mem_scheduler is not None: - message_item = ScheduleMessageItem( - user_id=target_user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - label=ANSWER_LABEL, - content=response, - timestamp=datetime.now(), - ) - self.mem_scheduler.submit_messages(messages=[message_item]) - return response diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py new file mode 100644 index 00000000..dd08954a --- /dev/null +++ b/src/memos/mem_scheduler/optimized_scheduler.py @@ -0,0 +1,124 @@ +from typing import TYPE_CHECKING + +from memos.configs.mem_scheduler import 
GeneralSchedulerConfig +from memos.log import get_logger +from memos.mem_cube.general import GeneralMemCube +from memos.mem_scheduler.general_scheduler import GeneralScheduler +from memos.mem_scheduler.schemas.general_schemas import ( + MemCubeID, + UserID, +) +from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory + + +if TYPE_CHECKING: + from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem + + +logger = get_logger(__name__) + + +class OptimizedScheduler(GeneralScheduler): + """Optimized scheduler with improved working memory management""" + + def __init__(self, config: GeneralSchedulerConfig): + super().__init__(config) + + def replace_working_memory( + self, + user_id: UserID | str, + mem_cube_id: MemCubeID | str, + mem_cube: GeneralMemCube, + original_memory: list[TextualMemoryItem], + new_memory: list[TextualMemoryItem], + ) -> None | list[TextualMemoryItem]: + """Replace working memory with new memories after reranking.""" + text_mem_base = mem_cube.text_mem + if isinstance(text_mem_base, TreeTextMemory): + text_mem_base: TreeTextMemory = text_mem_base + + # process rerank memories with llm + query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id] + # Sync with database to get latest query history + query_db_manager.sync_with_orm() + + query_history = query_db_manager.obj.get_queries_with_timesort() + memories_with_new_order, rerank_success_flag = ( + self.retriever.process_and_rerank_memories( + queries=query_history, + original_memory=original_memory, + new_memory=new_memory, + top_k=self.top_k, + ) + ) + + # Apply combined filtering (unrelated + redundant) + logger.info( + f"Applying combined unrelated and redundant memory filtering to {len(memories_with_new_order)} memories" + ) + filtered_memories, filtering_success_flag = ( + self.retriever.filter_unrelated_and_redundant_memories( + query_history=query_history, + memories=memories_with_new_order, + ) + ) + + if filtering_success_flag: + logger.info( + 
f"Combined filtering completed successfully. " + f"Filtered from {len(memories_with_new_order)} to {len(filtered_memories)} memories" + ) + memories_with_new_order = filtered_memories + else: + logger.warning( + "Combined filtering failed - keeping memories as fallback. " + f"Count: {len(memories_with_new_order)}" + ) + + # Update working memory monitors + query_keywords = query_db_manager.obj.get_keywords_collections() + logger.info( + f"Processing {len(memories_with_new_order)} memories with {len(query_keywords)} query keywords" + ) + new_working_memory_monitors = self.transform_working_memories_to_monitors( + query_keywords=query_keywords, + memories=memories_with_new_order, + ) + + if not rerank_success_flag: + for one in new_working_memory_monitors: + one.sorting_score = 0 + + logger.info(f"update {len(new_working_memory_monitors)} working_memory_monitors") + self.monitor.update_working_memory_monitors( + new_working_memory_monitors=new_working_memory_monitors, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + ) + + # Use the filtered and reranked memories directly + text_mem_base.replace_working_memory(memories=memories_with_new_order) + + # Update monitor after replacing working memory + mem_monitors: list[MemoryMonitorItem] = self.monitor.working_memory_monitors[user_id][ + mem_cube_id + ].obj.get_sorted_mem_monitors(reverse=True) + new_working_memories = [mem_monitor.tree_memory_item for mem_monitor in mem_monitors] + + logger.info( + f"The working memory has been replaced with {len(memories_with_new_order)} new memories." 
+ ) + self.log_working_memory_replacement( + original_memory=original_memory, + new_memory=new_working_memories, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + log_func_callback=self._submit_web_logs, + ) + else: + logger.error("memory_base is not supported") + memories_with_new_order = new_memory + + return memories_with_new_order diff --git a/src/memos/mem_scheduler/orm_modules/__init__.py b/src/memos/mem_scheduler/orm_modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/memos/mem_scheduler/orm_modules/base_model.py b/src/memos/mem_scheduler/orm_modules/base_model.py new file mode 100644 index 00000000..9d75a12b --- /dev/null +++ b/src/memos/mem_scheduler/orm_modules/base_model.py @@ -0,0 +1,635 @@ +import json +import os +import tempfile +import time + +from abc import abstractmethod +from datetime import datetime, timedelta +from pathlib import Path +from typing import Any, TypeVar + +from sqlalchemy import Boolean, Column, DateTime, String, Text, and_, create_engine +from sqlalchemy.engine import Engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import Session, sessionmaker + +from memos.log import get_logger +from memos.mem_user.user_manager import UserManager + + +T = TypeVar("T") # The model type (MemoryMonitorManager, QueryMonitorManager, etc.) 
+ORM = TypeVar("ORM") # The ORM model type + +logger = get_logger(__name__) + +Base = declarative_base() + + +class LockableORM(Base): + """Abstract base class for lockable ORM models""" + + __abstract__ = True + + # Primary composite key + user_id = Column(String(255), primary_key=True) + mem_cube_id = Column(String(255), primary_key=True) + + # Serialized data + serialized_data = Column(Text, nullable=False) + + lock_acquired = Column(Boolean, default=False) + lock_expiry = Column(DateTime, nullable=True) + + # Version control tag (0-255, cycles back to 0) + version_control = Column(String(3), default="0") + + +class BaseDBManager(UserManager): + """Abstract base class for database managers with proper locking mechanism + + This class provides a foundation for managing database operations with + distributed locking capabilities to ensure data consistency across + multiple processes or threads. + """ + + def __init__( + self, + engine: Engine, + user_id: str | None = None, + mem_cube_id: str | None = None, + lock_timeout: int = 10, + ): + """Initialize the database manager + + Args: + engine: SQLAlchemy engine instance + user_id: Unique identifier for the user + mem_cube_id: Unique identifier for the memory cube + lock_timeout: Timeout in seconds for lock acquisition + """ + # Do not use super init func to avoid UserManager initialization + self.engine = engine + self.SessionLocal = None + self.obj = None + self.user_id = user_id + self.mem_cube_id = mem_cube_id + self.lock_timeout = lock_timeout + self.last_version_control = None # Track the last version control tag + + self.init_manager( + engine=self.engine, + user_id=self.user_id, + mem_cube_id=self.mem_cube_id, + ) + + @property + @abstractmethod + def orm_class(self) -> type[LockableORM]: + """Return the ORM model class for this manager + + Returns: + The SQLAlchemy ORM model class + """ + raise NotImplementedError() + + @property + @abstractmethod + def obj_class(self) -> Any: + """Return the business 
object class for this manager + + Returns: + The business logic object class + """ + raise NotImplementedError() + + def init_manager(self, engine: Engine, user_id: str, mem_cube_id: str): + """Initialize the database manager with engine and identifiers + + Args: + engine: SQLAlchemy engine instance + user_id: User identifier + mem_cube_id: Memory cube identifier + + Raises: + RuntimeError: If database initialization fails + """ + try: + self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + logger.info(f"{self.orm_class} initialized with engine {engine}") + logger.info(f"Set user_id to {user_id}; mem_cube_id to {mem_cube_id}") + + # Create tables if they don't exist + self._create_table_with_error_handling(engine) + logger.debug(f"Successfully created/verified table for {self.orm_class.__tablename__}") + + except Exception as e: + error_msg = f"Failed to initialize database manager for {self.orm_class.__name__}: {e}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + def _create_table_with_error_handling(self, engine: Engine): + """Create table with proper error handling for common database conflicts + + Args: + engine: SQLAlchemy engine instance + + Raises: + RuntimeError: If table creation fails after handling known issues + """ + try: + self.orm_class.__table__.create(bind=engine, checkfirst=True) + except Exception as e: + error_str = str(e).lower() + + # Handle common SQLite index already exists error + if "index" in error_str and "already exists" in error_str: + logger.warning(f"Index already exists for {self.orm_class.__tablename__}: {e}") + # Try to create just the table without indexes + try: + # Create a temporary table definition without indexes + table_without_indexes = self.orm_class.__table__.copy() + table_without_indexes._indexes.clear() # Remove all indexes + table_without_indexes.create(bind=engine, checkfirst=True) + logger.info( + f"Created table {self.orm_class.__tablename__} 
without problematic indexes" + ) + except Exception as table_error: + logger.error(f"Failed to create table even without indexes: {table_error}") + raise + else: + # Re-raise other types of errors + raise + + def _get_session(self) -> Session: + """Get a database session""" + return self.SessionLocal() + + def _serialize(self, obj: T) -> str: + """Serialize the object to JSON""" + if hasattr(obj, "to_json"): + return obj.to_json() + return json.dumps(obj) + + def _deserialize(self, data: str, model_class: type[T]) -> T: + """Deserialize JSON to object""" + if hasattr(model_class, "from_json"): + return model_class.from_json(data) + return json.loads(data) + + def acquire_lock(self, block: bool = True, **kwargs) -> bool: + """Acquire a distributed lock for the current user and memory cube + + Args: + block: Whether to block until lock is acquired + **kwargs: Additional filter criteria + + Returns: + True if lock was acquired, False otherwise + """ + session = self._get_session() + + try: + now = datetime.now() + expiry = now + timedelta(seconds=self.lock_timeout) + + # Query for existing record with lock information + query = ( + session.query(self.orm_class) + .filter_by(**kwargs) + .filter( + and_( + self.orm_class.user_id == self.user_id, + self.orm_class.mem_cube_id == self.mem_cube_id, + ) + ) + ) + + record = query.first() + + # If no record exists, lock can be acquired immediately + if record is None: + logger.info( + f"No existing record found for {self.user_id}/{self.mem_cube_id}, lock can be acquired" + ) + return True + + # Check if lock is currently held and not expired + if record.lock_acquired and record.lock_expiry and now < record.lock_expiry: + if block: + # Wait for lock to be released or expire + logger.info( + f"Waiting for lock to be released for {self.user_id}/{self.mem_cube_id}" + ) + while record.lock_acquired and record.lock_expiry and now < record.lock_expiry: + time.sleep(0.1) # Small delay before retry + session.refresh(record) # Refresh 
record state + now = datetime.now() + else: + logger.warning( + f"Lock is held for {self.user_id}/{self.mem_cube_id}, cannot acquire" + ) + return False + + # Acquire the lock by updating the record + query.update( + { + "lock_acquired": True, + "lock_expiry": expiry, + }, + synchronize_session=False, + ) + + session.commit() + logger.info(f"Lock acquired for {self.user_id}/{self.mem_cube_id}") + return True + + except Exception as e: + session.rollback() + logger.error(f"Failed to acquire lock for {self.user_id}/{self.mem_cube_id}: {e}") + return False + finally: + session.close() + + def release_locks(self, user_id: str, mem_cube_id: str, **kwargs): + """Release locks for the specified user and memory cube + + Args: + user_id: User identifier + mem_cube_id: Memory cube identifier + **kwargs: Additional filter criteria + """ + session = self._get_session() + + try: + # Update all matching records to release locks + result = ( + session.query(self.orm_class) + .filter_by(**kwargs) + .filter( + and_( + self.orm_class.user_id == user_id, self.orm_class.mem_cube_id == mem_cube_id + ) + ) + .update( + { + "lock_acquired": False, + "lock_expiry": None, # Clear expiry time as well + }, + synchronize_session=False, + ) + ) + session.commit() + logger.info(f"Lock released for {user_id}/{mem_cube_id} (affected {result} records)") + + except Exception as e: + session.rollback() + logger.error(f"Failed to release lock for {user_id}/{mem_cube_id}: {e}") + finally: + session.close() + + def _get_primary_key(self) -> dict[str, Any]: + """Get the primary key dictionary for the current instance + + Returns: + Dictionary containing user_id and mem_cube_id + """ + return {"user_id": self.user_id, "mem_cube_id": self.mem_cube_id} + + def _increment_version_control(self, current_tag: str) -> str: + """Increment the version control tag, cycling from 255 back to 0 + + Args: + current_tag: Current version control tag as string + + Returns: + Next version control tag as string + """ + 
try: + current_value = int(current_tag) + next_value = (current_value + 1) % 256 # Cycle from 255 back to 0 + return str(next_value) + except (ValueError, TypeError): + # If current_tag is invalid, start from 0 + logger.warning(f"Invalid version_control '{current_tag}', resetting to '0'") + return "0" + + @abstractmethod + def merge_items(self, orm_instance, obj_instance, size_limit): + """Merge items from database with current object instance + + Args: + orm_instance: ORM instance from database + obj_instance: Current business object instance + size_limit: Maximum number of items to keep after merge + """ + + def sync_with_orm(self, size_limit: int | None = None) -> None: + """ + Synchronize data between the database and the business object. + + This method performs a three-step synchronization process: + 1. Acquire lock and get existing data from database + 2. Merge database items with current object items + 3. Write merged data back to database and release lock + + Args: + size_limit: Optional maximum number of items to keep after synchronization. + If specified, only the most recent items will be retained. + """ + logger.info( + f"Starting sync_with_orm for {self.user_id}/{self.mem_cube_id} with size_limit={size_limit}" + ) + user_id = self.user_id + mem_cube_id = self.mem_cube_id + + session = self._get_session() + + try: + # Acquire lock before any database operations + lock_status = self.acquire_lock(block=True) + if not lock_status: + logger.error("Failed to acquire lock for synchronization") + return + + # 1. 
Get existing data from database + orm_instance = ( + session.query(self.orm_class) + .filter_by(user_id=user_id, mem_cube_id=mem_cube_id) + .first() + ) + + # If no existing record, create a new one + if orm_instance is None: + if self.obj is None: + logger.warning("No object to synchronize and no existing database record") + return + + orm_instance = self.orm_class( + user_id=user_id, + mem_cube_id=mem_cube_id, + serialized_data=self.obj.to_json(), + version_control="0", # Start with tag 0 for new records + ) + logger.info( + "No existing ORM instance found. Created a new one. " + "Note: size_limit was not applied because there is no existing data to merge." + ) + session.add(orm_instance) + session.commit() + # Update last_version_control for new record + self.last_version_control = "0" + return + + # 2. Check version control and merge data from database with current object + if self.obj is not None: + current_db_tag = orm_instance.version_control + new_tag = self._increment_version_control(current_db_tag) + # Check if this is the first sync (last_version_control is None) + if self.last_version_control is None: + # First sync, increment version and perform merge + logger.info( + f"First sync, incrementing version from {current_db_tag} to {new_tag} for {self.user_id}/{self.mem_cube_id}" + ) + elif current_db_tag == self.last_version_control: + logger.info( + f"Version control unchanged ({current_db_tag}), directly update {self.user_id}/{self.mem_cube_id}" + ) + else: + # Version control has changed, increment it and perform merge + logger.info( + f"Version control changed from {self.last_version_control} to {current_db_tag}, incrementing to {new_tag} for {self.user_id}/{self.mem_cube_id}" + ) + try: + self.merge_items( + orm_instance=orm_instance, obj_instance=self.obj, size_limit=size_limit + ) + except Exception as merge_error: + logger.error(f"Error during merge_items: {merge_error}", exc_info=True) + logger.warning("Continuing with current object data without 
merge") + + # 3. Write merged data back to database + orm_instance.serialized_data = self.obj.to_json() + orm_instance.version_control = new_tag + logger.info(f"Updated serialized_data for {self.user_id}/{self.mem_cube_id}") + + # Update last_version_control to current value + self.last_version_control = orm_instance.version_control + else: + logger.warning("No current object to merge with database data") + + session.commit() + logger.info(f"Synchronization completed for {self.user_id}/{self.mem_cube_id}") + + except Exception as e: + session.rollback() + logger.error( + f"Error during synchronization for {user_id}/{mem_cube_id}: {e}", exc_info=True + ) + finally: + # Always release locks and close session + self.release_locks(user_id=user_id, mem_cube_id=mem_cube_id) + session.close() + + def save_to_db(self, obj_instance) -> None: + """Save the current state of the business object to the database + + Args: + obj_instance: The business object instance to save + """ + user_id = self.user_id + mem_cube_id = self.mem_cube_id + + session = self._get_session() + + try: + # Acquire lock before database operations + lock_status = self.acquire_lock(block=True) + if not lock_status: + logger.error("Failed to acquire lock for saving to database") + return + + # Check if record already exists + orm_instance = ( + session.query(self.orm_class) + .filter_by(user_id=user_id, mem_cube_id=mem_cube_id) + .first() + ) + + if orm_instance is None: + # Create new record + orm_instance = self.orm_class( + user_id=user_id, + mem_cube_id=mem_cube_id, + serialized_data=obj_instance.to_json(), + version_control="0", # Start with version 0 for new records + ) + session.add(orm_instance) + logger.info(f"Created new database record for {user_id}/{mem_cube_id}") + # Update last_version_control for new record + self.last_version_control = "0" + else: + # Update existing record with version control + current_version = orm_instance.version_control + new_version = 
self._increment_version_control(current_version) + orm_instance.serialized_data = obj_instance.to_json() + orm_instance.version_control = new_version + logger.info( + f"Updated existing database record for {user_id}/{mem_cube_id} with version {new_version}" + ) + # Update last_version_control + self.last_version_control = new_version + + session.commit() + + except Exception as e: + session.rollback() + logger.error(f"Error saving to database for {user_id}/{mem_cube_id}: {e}") + finally: + # Always release locks and close session + self.release_locks(user_id=user_id, mem_cube_id=mem_cube_id) + session.close() + + def load_from_db(self, acquire_lock: bool = False): + """Load the business object from the database + + Args: + acquire_lock: Whether to acquire a lock during the load operation + + Returns: + The deserialized business object instance, or None if not found + """ + user_id = self.user_id + mem_cube_id = self.mem_cube_id + + session = self._get_session() + + try: + if acquire_lock: + lock_status = self.acquire_lock(block=True) + if not lock_status: + logger.error("Failed to acquire lock for loading from database") + return None + + # Query for the database record + orm_instance = ( + session.query(self.orm_class) + .filter_by(user_id=user_id, mem_cube_id=mem_cube_id) + .first() + ) + + if orm_instance is None: + logger.info(f"No database record found for {user_id}/{mem_cube_id}") + return None + + # Deserialize the business object from JSON + db_instance = self.obj_class.from_json(orm_instance.serialized_data) + # Update last_version_control to track the loaded version + self.last_version_control = orm_instance.version_control + logger.info( + f"Successfully loaded object from database for {user_id}/{mem_cube_id} with version {orm_instance.version_control}" + ) + + return db_instance + + except Exception as e: + logger.error(f"Error loading from database for {user_id}/{mem_cube_id}: {e}") + return None + finally: + if acquire_lock: + 
self.release_locks(user_id=user_id, mem_cube_id=mem_cube_id) + session.close() + + def close(self): + """Close the database manager and clean up resources + + This method releases any held locks and disposes of the database engine. + Should be called when the manager is no longer needed. + """ + try: + # Release any locks held by this manager instance + if self.user_id and self.mem_cube_id: + self.release_locks(user_id=self.user_id, mem_cube_id=self.mem_cube_id) + logger.info(f"Released locks for {self.user_id}/{self.mem_cube_id}") + + # Dispose of the engine to close all connections + if self.engine: + self.engine.dispose() + logger.info("Database engine disposed") + + except Exception as e: + logger.error(f"Error during close operation: {e}") + + @staticmethod + def create_default_engine() -> Engine: + """Create SQLAlchemy engine with default database path + + Returns: + SQLAlchemy Engine instance using default scheduler_orm.db + """ + temp_dir = tempfile.mkdtemp() + db_path = os.path.join(temp_dir, "test_scheduler_orm.db") + + # Clean up any existing file (though unlikely) + if os.path.exists(db_path): + os.remove(db_path) + # Remove the temp directory if still exists (should be empty) + if os.path.exists(temp_dir) and not os.listdir(temp_dir): + os.rmdir(temp_dir) + + # Ensure parent directory exists (re-create in case rmdir removed it) + parent_dir = Path(db_path).parent + parent_dir.mkdir(parents=True, exist_ok=True) + + # Log the creation of the default engine with database path + logger.info( + "Creating default SQLAlchemy engine with temporary SQLite database at: %s", db_path + ) + + return create_engine(f"sqlite:///{db_path}", echo=False) + + @staticmethod + def create_engine_from_db_path(db_path: str) -> Engine: + """Create SQLAlchemy engine from database path + + Args: + db_path: Path to database file + + Returns: + SQLAlchemy Engine instance + """ + # Ensure the directory exists + Path(db_path).parent.mkdir(parents=True, exist_ok=True) + + return 
create_engine(f"sqlite:///{db_path}", echo=False) + + @staticmethod + def create_mysql_db_path( + host: str = "localhost", + port: int = 3306, + username: str = "root", + password: str = "", + database: str = "scheduler_orm", + charset: str = "utf8mb4", + ) -> str: + """Create MySQL database connection URL + + Args: + host: MySQL server hostname + port: MySQL server port + username: Database username + password: Database password (optional) + database: Database name + charset: Character set encoding + + Returns: + MySQL connection URL string + """ + # Build MySQL connection URL with proper formatting + if password: + db_path = ( + f"mysql+pymysql://{username}:{password}@{host}:{port}/{database}?charset={charset}" + ) + else: + db_path = f"mysql+pymysql://{username}@{host}:{port}/{database}?charset={charset}" + return db_path diff --git a/src/memos/mem_scheduler/orm_modules/monitor_models.py b/src/memos/mem_scheduler/orm_modules/monitor_models.py new file mode 100644 index 00000000..a5a04eb4 --- /dev/null +++ b/src/memos/mem_scheduler/orm_modules/monitor_models.py @@ -0,0 +1,261 @@ +from typing import TypeVar + +from sqlalchemy import Index +from sqlalchemy.engine import Engine + +from memos.log import get_logger +from memos.mem_scheduler.schemas.monitor_schemas import ( + MemoryMonitorItem, + MemoryMonitorManager, + QueryMonitorItem, + QueryMonitorQueue, +) + +from .base_model import BaseDBManager, LockableORM + + +logger = get_logger(__name__) + +# Type variables for generic type hints +T = TypeVar("T") # The model type (MemoryMonitorManager, QueryMonitorManager, etc.) +ORM = TypeVar("ORM") # The ORM model type + + +class MemoryMonitorManagerORM(LockableORM): + """ORM model for MemoryMonitorManager persistence + + This table stores serialized MemoryMonitorManager instances with + proper indexing for efficient user and memory cube lookups. 
+ """ + + __tablename__ = "memory_monitor_manager" + + # Database indexes for performance optimization + __table_args__ = (Index("idx_memory_monitor_user_memcube", "user_id", "mem_cube_id"),) + + +class QueryMonitorQueueORM(LockableORM): + """ORM model for QueryMonitorQueue persistence + + This table stores serialized QueryMonitorQueue instances with + proper indexing for efficient user and memory cube lookups. + """ + + __tablename__ = "query_monitor_queue" + + # Database indexes for performance optimization + __table_args__ = (Index("idx_query_monitor_user_memcube", "user_id", "mem_cube_id"),) + + +class DBManagerForMemoryMonitorManager(BaseDBManager): + """Database manager for MemoryMonitorManager objects + + This class handles persistence, synchronization, and locking + for MemoryMonitorManager instances in the database. + """ + + def __init__( + self, + engine: Engine, + user_id: str | None = None, + mem_cube_id: str | None = None, + obj: MemoryMonitorManager | None = None, + lock_timeout: int = 10, + ): + """ + Initialize the MemoryMonitorManager database manager. 
+ + Args: + engine: SQLAlchemy engine instance + user_id: Unique identifier for the user + mem_cube_id: Unique identifier for the memory cube + obj: Optional MemoryMonitorManager instance to manage + lock_timeout: Timeout in seconds for lock acquisition + """ + super().__init__( + engine=engine, user_id=user_id, mem_cube_id=mem_cube_id, lock_timeout=lock_timeout + ) + self.obj: MemoryMonitorManager | None = obj + + @property + def orm_class(self) -> type[MemoryMonitorManagerORM]: + return MemoryMonitorManagerORM + + @property + def obj_class(self) -> type[MemoryMonitorManager]: + return MemoryMonitorManager + + def merge_items( + self, + orm_instance: MemoryMonitorManagerORM, + obj_instance: MemoryMonitorManager, + size_limit: int, + ): + """Merge memory monitor items from database with current object + + This method combines items from the database with items in the current + object, prioritizing current object items and applying size limits. + + Args: + orm_instance: ORM instance containing serialized database data + obj_instance: Current MemoryMonitorManager instance + size_limit: Maximum number of items to keep after merge + + Returns: + Updated obj_instance with merged items + """ + logger.debug(f"Starting merge_items for MemoryMonitorManager with size_limit={size_limit}") + + try: + # Deserialize the database instance + db_instance: MemoryMonitorManager = MemoryMonitorManager.from_json( + orm_instance.serialized_data + ) + except Exception as e: + logger.error(f"Failed to deserialize database instance: {e}", exc_info=True) + logger.warning("Skipping merge due to deserialization error, using current object only") + return obj_instance + + # Merge items - prioritize existing ones in current object + merged_items: list[MemoryMonitorItem] = [] + seen_ids = set() + + # First, add all items from current object (higher priority) + for item in obj_instance.memories: + if item.item_id not in seen_ids: + merged_items.append(item) + seen_ids.add(item.item_id) + + # 
Then, add items from database that aren't in current object + for item in db_instance.memories: + if item.item_id not in seen_ids: + merged_items.append(item) + seen_ids.add(item.item_id) + + # Apply size limit if specified (keep most recent items) + if size_limit is not None and size_limit > 0: + try: + # Sort by sorting_score descending (highest priority first) and take top N + # Note: MemoryMonitorItem doesn't have timestamp, so we use sorting_score instead + merged_items = sorted(merged_items, key=lambda x: x.sorting_score, reverse=True)[ + :size_limit + ] + logger.debug(f"Applied size limit of {size_limit}, kept {len(merged_items)} items") + except AttributeError as e: + logger.error(f"Error sorting MemoryMonitorItem objects: {e}") + logger.error( + "Available attributes: " + + ", ".join(dir(merged_items[0]) if merged_items else []) + ) + raise + except Exception as e: + logger.error(f"Unexpected error during sorting: {e}") + raise + + # Update the object with merged items + obj_instance.memories = merged_items + + logger.info( + f"Merged {len(merged_items)} memory items for {obj_instance} (size_limit: {size_limit})" + ) + + return obj_instance + + +class DBManagerForQueryMonitorQueue(BaseDBManager): + """Database manager for QueryMonitorQueue objects + + This class handles persistence, synchronization, and locking + for QueryMonitorQueue instances in the database. + """ + + def __init__( + self, + engine: Engine, + user_id: str | None = None, + mem_cube_id: str | None = None, + obj: QueryMonitorQueue | None = None, + lock_timeout: int = 10, + ): + """ + Initialize the QueryMonitorQueue database manager. 
+ + Args: + engine: SQLAlchemy engine instance + user_id: Unique identifier for the user + mem_cube_id: Unique identifier for the memory cube + obj: Optional QueryMonitorQueue instance to manage + lock_timeout: Timeout in seconds for lock acquisition + """ + super().__init__( + engine=engine, user_id=user_id, mem_cube_id=mem_cube_id, lock_timeout=lock_timeout + ) + self.obj: QueryMonitorQueue | None = obj + + @property + def orm_class(self) -> type[QueryMonitorQueueORM]: + return QueryMonitorQueueORM + + @property + def obj_class(self) -> type[QueryMonitorQueue]: + return QueryMonitorQueue + + def merge_items( + self, orm_instance: QueryMonitorQueueORM, obj_instance: QueryMonitorQueue, size_limit: int + ): + """Merge query monitor items from database with current queue + + This method combines items from the database with items in the current + queue, prioritizing current queue items and applying size limits. + + Args: + orm_instance: ORM instance containing serialized database data + obj_instance: Current QueryMonitorQueue instance + size_limit: Maximum number of items to keep after merge + + Returns: + Updated obj_instance with merged items + """ + try: + # Deserialize the database instance + db_instance: QueryMonitorQueue = QueryMonitorQueue.from_json( + orm_instance.serialized_data + ) + except Exception as e: + logger.error(f"Failed to deserialize database instance: {e}") + logger.warning("Skipping merge due to deserialization error, using current object only") + return obj_instance + + # Merge items - prioritize existing ones in current object + merged_items: list[QueryMonitorItem] = [] + seen_ids = set() + + # First, add all items from current queue (higher priority) + for item in obj_instance.get_queue_content_without_pop(): + if item.item_id not in seen_ids: + merged_items.append(item) + seen_ids.add(item.item_id) + + # Then, add items from database queue that aren't in current queue + for item in db_instance.get_queue_content_without_pop(): + if 
item.item_id not in seen_ids: + merged_items.append(item) + seen_ids.add(item.item_id) + + # Apply size limit if specified (keep most recent items) + if size_limit is not None and size_limit > 0: + # Sort by timestamp descending (newest first) and take top N + merged_items = sorted(merged_items, key=lambda x: x.timestamp, reverse=True)[ + :size_limit + ] + + # Update the queue with merged items + obj_instance.clear() # Clear existing items + for item in merged_items: + obj_instance.put(item) # Add merged items back + + logger.info( + f"Merged {len(merged_items)} query items for {obj_instance} (size_limit: {size_limit})" + ) + + return obj_instance diff --git a/src/memos/mem_scheduler/scheduler_factory.py b/src/memos/mem_scheduler/scheduler_factory.py index 5bcd0e2b..3cd406f3 100644 --- a/src/memos/mem_scheduler/scheduler_factory.py +++ b/src/memos/mem_scheduler/scheduler_factory.py @@ -3,6 +3,7 @@ from memos.configs.mem_scheduler import SchedulerConfigFactory from memos.mem_scheduler.base_scheduler import BaseScheduler from memos.mem_scheduler.general_scheduler import GeneralScheduler +from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler class SchedulerFactory(BaseScheduler): @@ -10,6 +11,7 @@ class SchedulerFactory(BaseScheduler): backend_to_class: ClassVar[dict[str, Any]] = { "general_scheduler": GeneralScheduler, + "optimized_scheduler": OptimizedScheduler, } @classmethod diff --git a/src/memos/mem_scheduler/schemas/monitor_schemas.py b/src/memos/mem_scheduler/schemas/monitor_schemas.py index 65238d72..f148f30d 100644 --- a/src/memos/mem_scheduler/schemas/monitor_schemas.py +++ b/src/memos/mem_scheduler/schemas/monitor_schemas.py @@ -1,3 +1,4 @@ +import json import threading from collections import Counter @@ -30,6 +31,8 @@ class QueryMonitorItem(BaseModel, DictConversionMixin): item_id: str = Field( description="Unique identifier for the query item", default_factory=lambda: str(uuid4()) ) + user_id: str = Field(..., description="Required user 
identifier", min_length=1) + mem_cube_id: str = Field(..., description="Required memory cube identifier", min_length=1) query_text: str = Field( ..., description="The actual user query text content", @@ -111,7 +114,8 @@ def get_keywords_collections(self) -> Counter: """ with self.mutex: logger.debug(f"Thread {threading.get_ident()} acquired mutex.") - all_keywords = [kw for item in self.queue for kw in item.keywords] + # Fix: Handle None keywords safely + all_keywords = [kw for item in self.queue if item.keywords for kw in item.keywords] return Counter(all_keywords) def get_queries_with_timesort(self, reverse: bool = True) -> list[str]: @@ -132,9 +136,62 @@ def get_queries_with_timesort(self, reverse: bool = True) -> list[str]: for monitor in sorted(self.queue, key=lambda x: x.timestamp, reverse=reverse) ] + def to_json(self) -> str: + """Serialize the queue to a JSON string. + + Args: + item_serializer: Optional function to serialize individual items. + If not provided, items must be JSON-serializable. + + Returns: + A JSON string representing the queue's content and maxsize. + """ + with self.mutex: + serialized_items = [item.to_json() for item in self.queue] + + data = {"maxsize": self.maxsize, "items": serialized_items} + return json.dumps(data, ensure_ascii=False, indent=2) + + @classmethod + def from_json(cls, json_str: str) -> "QueryMonitorQueue": + """Create a new AutoDroppingQueue from a JSON string. + + Args: + json_str: JSON string created by to_json() + item_deserializer: Optional function to reconstruct items from dicts. + If not provided, items are used as-is. + + Returns: + A new AutoDroppingQueue instance with deserialized data. 
+ """ + data = json.loads(json_str) + maxsize = data.get("maxsize", 0) + item_strs = data.get("items", []) + + queue = cls(maxsize=maxsize) + + items = [QueryMonitorItem.from_json(json_str=item_str) for item_str in item_strs] + + # Fix: Add error handling for put operations + for item in items: + try: + queue.put(item) # Use put() to respect maxsize and auto-drop behavior + except Exception as e: + logger.error(f"Failed to add item to queue: {e}") + # Continue with other items instead of failing completely + + return queue + # ============== Memories ============== class MemoryMonitorItem(BaseModel, DictConversionMixin): + """ + Represents a memory item in the monitoring system. + + Note: This class does NOT have a timestamp field, unlike QueryMonitorItem. + For sorting by recency, use sorting_score or importance_score instead. + """ + item_id: str = Field( description="Unique identifier for the memory item", default_factory=lambda: str(uuid4()) ) @@ -167,7 +224,7 @@ class MemoryMonitorItem(BaseModel, DictConversionMixin): recording_count: int = Field( default=1, description="How many times this memory has been recorded", - ge=1, # Greater than or equal to 1 + ge=1, ) @field_validator("tree_memory_item_mapping_key", mode="before") @@ -177,27 +234,28 @@ def generate_mapping_key(cls, v, values): # noqa: N805 return v def get_importance_score(self, weight_vector: list[float] | None = None) -> float: - """ - Calculate the effective score for the memory item. + return self._get_complex_importance_score(weight_vector=weight_vector) - Returns: - float: The importance_score if it has been initialized (>=0), - otherwise the recording_count converted to float. - - Note: - This method provides a unified way to retrieve a comparable score - for memory items, regardless of whether their importance has been explicitly set. 
- """ + def _get_complex_importance_score(self, weight_vector: list[float] | None = None) -> float: + """Calculate traditional importance score using existing logic""" if weight_vector is None: - logger.warning("weight_vector of get_importance_score is None.") + logger.warning("weight_vector of get_complex_score is None.") weight_vector = DEFAULT_WEIGHT_VECTOR_FOR_RANKING - assert sum(weight_vector) == 1 - normalized_keywords_score = min(self.keywords_score * weight_vector[1], 5) + + # Fix: Add proper validation for weight_vector + if not weight_vector or len(weight_vector) != 3 or abs(sum(weight_vector) - 1.0) > 1e-6: + raise ValueError("weight_vector must be provided, have length 3, and sum to 1.0") + + # Fix: Handle uninitialized scores safely + sorting_score = self.sorting_score if self.sorting_score != NOT_INITIALIZED else 0.0 + keywords_score = self.keywords_score if self.keywords_score != NOT_INITIALIZED else 0.0 + + normalized_keywords_score = min(keywords_score * weight_vector[1], 5) normalized_recording_count_score = min(self.recording_count * weight_vector[2], 2) self.importance_score = ( - self.sorting_score * weight_vector[0] - + normalized_keywords_score - + normalized_recording_count_score + sorting_score * weight_vector[0] + + normalized_keywords_score * weight_vector[1] + + normalized_recording_count_score * weight_vector[2] ) return self.importance_score @@ -258,7 +316,7 @@ def get_sorted_mem_monitors(self, reverse=True) -> list[MemoryMonitorItem]: def update_memories( self, new_memory_monitors: list[MemoryMonitorItem], partial_retention_number: int - ) -> MemoryMonitorItem: + ) -> list[MemoryMonitorItem]: # Fix: Correct return type """ Update memories based on monitor_working_memories. 
""" @@ -302,6 +360,13 @@ def update_memories( reverse=True, ) + # Fix: Add bounds checking to prevent IndexError + if partial_retention_number > len(sorted_old_mem_monitors): + partial_retention_number = len(sorted_old_mem_monitors) + logger.info( + f"partial_retention_number adjusted to {partial_retention_number} to match available old memories" + ) + # Keep the top N old memories memories_to_remove = sorted_old_mem_monitors[partial_retention_number:] memories_to_change_score = sorted_old_mem_monitors[:partial_retention_number] @@ -312,19 +377,21 @@ def update_memories( for memory in memories_to_change_score: memory.sorting_score = 0 - memory.recording_count = 0 + memory.recording_count = 1 memory.keywords_score = 0 # Step 4: Enforce max_capacity if set - sorted_memories = sorted( - self.memories, - key=lambda item: item.get_importance_score( - weight_vector=DEFAULT_WEIGHT_VECTOR_FOR_RANKING - ), - reverse=True, - ) - # Keep only the top max_capacity memories - self.memories = sorted_memories[: self.max_capacity] + # Fix: Handle max_capacity safely + if self.max_capacity is not None: + sorted_memories = sorted( + self.memories, + key=lambda item: item.get_importance_score( + weight_vector=DEFAULT_WEIGHT_VECTOR_FOR_RANKING + ), + reverse=True, + ) + # Keep only the top max_capacity memories + self.memories = sorted_memories[: self.max_capacity] # Log the update result logger.info( diff --git a/src/memos/mem_scheduler/utils/config_utils.py b/src/memos/mem_scheduler/utils/config_utils.py new file mode 100644 index 00000000..8bb1050e --- /dev/null +++ b/src/memos/mem_scheduler/utils/config_utils.py @@ -0,0 +1,100 @@ +import json +import os + +from typing import Any + +import yaml + + +def flatten_dict( + data: dict[str, Any], parent_keys: list[str] | None = None, prefix: str = "" +) -> dict[str, str]: + """ + Recursively flattens a nested dictionary to generate environment variable keys following the specified format. 
+ Combines nested keys with underscores, converts to uppercase, and prepends a custom prefix if provided. + + Args: + data: Nested dictionary to be flattened (parsed from JSON/YAML) + parent_keys: List to track nested keys during recursion + prefix: Custom prefix to be added to all generated keys + + Returns: + Flattened dictionary with keys in PREFIX_KEY1_KEY2... format and string values + """ + parent_keys = parent_keys or [] + flat_data = {} + + for key, value in data.items(): + # Clean and standardize key: convert to uppercase, replace spaces/hyphens with underscores + clean_key = key.upper().replace(" ", "_").replace("-", "_") + current_keys = [*parent_keys, clean_key] + + if isinstance(value, dict): + # Recursively process nested dictionaries + nested_flat = flatten_dict(value, current_keys, prefix) + flat_data.update(nested_flat) + else: + # Construct full key name with prefix (if provided) and nested keys + if prefix: + full_key = f"{prefix.upper()}_{'_'.join(current_keys)}" + else: + full_key = "_".join(current_keys) + + # Process value: ensure string type, convert None to empty string + flat_value = "" if value is None else str(value).strip() + + flat_data[full_key] = flat_value + + return flat_data + + +def convert_config_to_env(input_file: str, output_file: str = ".env", prefix: str = "") -> None: + """ + Converts a JSON or YAML configuration file to a .env file with standardized environment variables. + Uses the flatten_dict function to generate keys in PREFIX_KEY1_KEY2... format. 
+ + Args: + input_file: Path to input configuration file (.json, .yaml, or .yml) + output_file: Path to output .env file (default: .env) + prefix: Custom prefix for all environment variable keys + + Raises: + FileNotFoundError: If input file does not exist + ValueError: If file format is unsupported or parsing fails + """ + # Check if input file exists + if not os.path.exists(input_file): + raise FileNotFoundError(f"Input file not found: {input_file}") + + # Parse input file based on extension + file_ext = os.path.splitext(input_file)[1].lower() + config_data: dict[str, Any] = {} + + try: + with open(input_file, encoding="utf-8") as f: + if file_ext in (".json",): + config_data = json.load(f) + elif file_ext in (".yaml", ".yml"): + config_data = yaml.safe_load(f) + else: + raise ValueError( + f"Unsupported file format: {file_ext}. Supported formats: .json, .yaml, .yml" + ) + except (json.JSONDecodeError, yaml.YAMLError) as e: + raise ValueError(f"Error parsing file: {e!s}") from e + + # Flatten configuration and generate environment variable key-value pairs + flat_config = flatten_dict(config_data, prefix=prefix) + + # Write to .env file + with open(output_file, "w", encoding="utf-8") as f: + for key, value in flat_config.items(): + # Handle values containing double quotes (use no surrounding quotes) + if '"' in value: + f.write(f"{key}={value}\n") + else: + f.write(f'{key}="{value}"\n') # Enclose regular values in double quotes + + print( + f"Conversion complete! 
Generated {output_file} with {len(flat_config)} environment variables" + ) diff --git a/src/memos/mem_scheduler/utils/db_utils.py b/src/memos/mem_scheduler/utils/db_utils.py new file mode 100644 index 00000000..5d7cc52c --- /dev/null +++ b/src/memos/mem_scheduler/utils/db_utils.py @@ -0,0 +1,33 @@ +import os +import sqlite3 + + +def print_db_tables(db_path: str): + """Print all table names and structures in the SQLite database""" + print(f"\n๐Ÿ” Checking database file: {db_path}") + + if not os.path.exists(db_path): + print(f"โŒ File does not exist! Path: {db_path}") + return + + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + # List all tables + cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") + tables = cursor.fetchall() + if not tables: + print("โŒ Database is empty, no tables created") + else: + print(f"โœ… Database contains {len(tables)} table(s):") + for (table_name,) in tables: + print(f" ๐Ÿ“‚ Table name: {table_name}") + + # Print table structure + cursor.execute(f"PRAGMA table_info({table_name});") + columns = cursor.fetchall() + print(" ๐Ÿงฉ Structure:") + for col in columns: + print(f" {col[1]} ({col[2]}) {'(PK)' if col[5] else ''}") + + conn.close() diff --git a/src/memos/mem_scheduler/utils/filter_utils.py b/src/memos/mem_scheduler/utils/filter_utils.py index 6055fe41..7aa0657e 100644 --- a/src/memos/mem_scheduler/utils/filter_utils.py +++ b/src/memos/mem_scheduler/utils/filter_utils.py @@ -60,7 +60,7 @@ def is_all_chinese(input_string: str) -> bool: install_command="pip install scikit-learn", install_link="https://scikit-learn.org/stable/install.html", ) -def filter_similar_memories( +def filter_vector_based_similar_memories( text_memories: list[str], similarity_threshold: float = 0.75 ) -> list[str]: """ diff --git a/src/memos/mem_scheduler/webservice_modules/__init__.py b/src/memos/mem_scheduler/webservice_modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/src/memos/mem_scheduler/general_modules/rabbitmq_service.py b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py similarity index 100% rename from src/memos/mem_scheduler/general_modules/rabbitmq_service.py rename to src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py diff --git a/src/memos/mem_scheduler/general_modules/redis_service.py b/src/memos/mem_scheduler/webservice_modules/redis_service.py similarity index 100% rename from src/memos/mem_scheduler/general_modules/redis_service.py rename to src/memos/mem_scheduler/webservice_modules/redis_service.py diff --git a/src/memos/memories/activation/kv.py b/src/memos/memories/activation/kv.py index 06cef794..2fa08590 100644 --- a/src/memos/memories/activation/kv.py +++ b/src/memos/memories/activation/kv.py @@ -1,9 +1,10 @@ import os import pickle + from datetime import datetime from importlib.metadata import version -from packaging.version import Version +from packaging.version import Version from transformers import DynamicCache from memos.configs.memory import KVCacheMemoryConfig diff --git a/src/memos/memories/textual/tree_text_memory/organize/handler.py b/src/memos/memories/textual/tree_text_memory/organize/handler.py index a1121fcd..271902ca 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/handler.py +++ b/src/memos/memories/textual/tree_text_memory/organize/handler.py @@ -1,5 +1,6 @@ import json import re + from datetime import datetime from dateutil import parser @@ -14,6 +15,7 @@ MEMORY_RELATION_RESOLVER_PROMPT, ) + logger = get_logger(__name__) diff --git a/src/memos/templates/mem_scheduler_prompts.py b/src/memos/templates/mem_scheduler_prompts.py index a1fa5324..2281fc77 100644 --- a/src/memos/templates/mem_scheduler_prompts.py +++ b/src/memos/templates/mem_scheduler_prompts.py @@ -151,11 +151,216 @@ Answer: """ +MEMORY_FILTERING_PROMPT = """ +# Memory Relevance Filtering Task + +## Role +You are an intelligent memory filtering system. 
Your primary function is to analyze memory relevance and filter out memories that are completely unrelated to the user's query history. + +## Task Description +Analyze the provided memories and determine which ones are relevant to the user's query history: +1. Evaluate semantic relationship between each memory and the query history +2. Identify memories that are completely unrelated or irrelevant +3. Filter out memories that don't contribute to answering the queries +4. Preserve memories that provide context, evidence, or relevant information + +## Relevance Criteria +A memory is considered RELEVANT if it: +- Directly answers questions from the query history +- Provides context or background information related to the queries +- Contains information that could be useful for understanding the queries +- Shares semantic similarity with query topics or themes +- Contains keywords or concepts mentioned in the queries + +A memory is considered IRRELEVANT if it: +- Has no semantic connection to any query in the history +- Discusses completely unrelated topics +- Contains information that cannot help answer any query +- Is too generic or vague to be useful + +## Input Format +- Query History: List of user queries (chronological order) +- Memories: List of memory texts to be evaluated + +## Output Format Requirements +You MUST output a valid JSON object with EXACTLY the following structure: +{{ + "relevant_memories": [array_of_memory_indices], + "filtered_count": , + "reasoning": "string_explanation" +}} + +## Important Notes: +- Only output the JSON object, nothing else +- Do not include any markdown formatting or code block notation +- Ensure all brackets and quotes are properly closed +- The output must be parseable by a JSON parser +- Memory indices should correspond to the input order (0-based) + +## Processing Guidelines +1. Be conservative in filtering - when in doubt, keep the memory +2. Consider both direct and indirect relevance +3. 
Look for thematic connections, not just exact keyword matches +4. Preserve memories that provide valuable context + +## Current Task +Query History: {query_history} +Memories to Filter: {memories} + +Please provide your filtering analysis: +""" + +MEMORY_REDUNDANCY_FILTERING_PROMPT = """ +# Memory Redundancy Filtering Task + +## Role +You are an intelligent memory optimization system. Your primary function is to analyze memories and remove redundancy to improve memory quality and relevance. + +## Task Description +Analyze the provided memories and identify redundant ones: +1. **Redundancy Detection**: Find memories that contain the same core facts relevant to queries +2. **Best Memory Selection**: Keep only the most concise and focused version of redundant information +3. **Quality Preservation**: Ensure the final set covers all necessary information without redundancy + +## Redundancy Detection Criteria +A memory is considered REDUNDANT if it: +- Contains the same core fact as another memory that's relevant to the queries +- Provides the same information but with additional irrelevant details +- Repeats information that's already covered by a more concise memory +- Has overlapping content with another memory that serves the same purpose + +When redundancy is found, KEEP the memory that: +- Is more concise and focused +- Contains less irrelevant information +- Is more directly relevant to the queries +- Has higher information density + +## Input Format +- Query History: List of user queries (chronological order) +- Memories: List of memory texts to be evaluated + +## Output Format Requirements +You MUST output a valid JSON object with EXACTLY the following structure: +{{ + "kept_memories": [array_of_memory_indices_to_keep], + "redundant_groups": [ + {{ + "group_id": group_number, + "memories": [array_of_redundant_memory_indices], + "kept_memory": index_of_kept_memory, + "reason": "explanation_of_why_this_memory_was_kept" + }} + ], + "reasoning": "string_explanation_of_filtering_decisions" +}} + 
+## Important Notes: +- Only output the JSON object, nothing else +- Do not include any markdown formatting or code block notation +- Ensure all brackets and quotes are properly closed +- The output must be parseable by a JSON parser +- Memory indices should correspond to the input order (0-based) +- Be conservative in filtering - when in doubt, keep the memory +- Focus on semantic similarity, not just exact text matches + +## Processing Guidelines +1. First identify which memories are relevant to the queries +2. Group relevant memories by semantic similarity and core facts +3. Within each group, select the best memory (most concise, least noise) +4. Ensure the final set covers all necessary information without redundancy + +## Current Task +Query History: {query_history} +Memories to Filter: {memories} + +Please provide your redundancy filtering analysis: +""" + +MEMORY_COMBINED_FILTERING_PROMPT = """ +# Memory Combined Filtering Task + +## Role +You are an intelligent memory optimization system. Your primary function is to analyze memories and perform two types of filtering in sequence: +1. **Unrelated Memory Removal**: Remove memories that are completely unrelated to the user's query history +2. **Redundancy Removal**: Remove redundant memories by keeping only the most informative version + +## Task Description +Analyze the provided memories and perform comprehensive filtering: +1. **First Step - Unrelated Filtering**: Identify and remove memories that have no semantic connection to any query +2. 
**Second Step - Redundancy Filtering**: Group similar memories and keep only the best version from each group + +## Unrelated Memory Detection Criteria +A memory is considered UNRELATED if it: +- Has no semantic connection to any query in the history +- Discusses completely unrelated topics +- Contains information that cannot help answer any query +- Is too generic or vague to be useful + +## Redundancy Detection Criteria +A memory is considered REDUNDANT if it: +- Contains the same core fact as another memory that's relevant to the queries +- Provides the same information but with additional irrelevant details +- Repeats information that's already covered by a more concise memory +- Has overlapping content with another memory that serves the same purpose + +When redundancy is found, KEEP the memory that: +- Is more concise and focused +- Contains less irrelevant information +- Is more directly relevant to the queries +- Has higher information density + +## Input Format +- Query History: List of user queries (chronological order) +- Memories: List of memory texts to be evaluated + +## Output Format Requirements +You MUST output a valid JSON object with EXACTLY the following structure: +{{ + "kept_memories": [array_of_memory_indices_to_keep], + "unrelated_removed_count": number_of_unrelated_memories_removed, + "redundant_removed_count": number_of_redundant_memories_removed, + "redundant_groups": [ + {{ + "group_id": group_number, + "memories": [array_of_redundant_memory_indices], + "kept_memory": index_of_kept_memory, + "reason": "explanation_of_why_this_memory_was_kept" + }} + ], + "reasoning": "string_explanation_of_filtering_decisions" +}} + +## Important Notes: +- Only output the JSON object, nothing else +- Do not include any markdown formatting or code block notation +- Ensure all brackets and quotes are properly closed +- The output must be parseable by a JSON parser +- Memory indices should correspond to the input order (0-based) +- Be conservative in filtering - when in doubt, keep the memory +- Focus on semantic similarity, not just exact text matches + +## 
Processing Guidelines +1. **First, identify unrelated memories** and mark them for removal +2. **Then, group remaining memories** by semantic similarity and core facts +3. **Within each group, select the best memory** (most concise, least noise) +4. **Ensure the final set covers all necessary information** without redundancy +5. **Count how many memories were removed** for each reason + +## Current Task +Query History: {query_history} +Memories to Filter: {memories} + +Please provide your combined filtering analysis: +""" + PROMPT_MAPPING = { "intent_recognizing": INTENT_RECOGNIZING_PROMPT, "memory_reranking": MEMORY_RERANKING_PROMPT, "query_keywords_extraction": QUERY_KEYWORDS_EXTRACTION_PROMPT, + "memory_filtering": MEMORY_FILTERING_PROMPT, + "memory_redundancy_filtering": MEMORY_REDUNDANCY_FILTERING_PROMPT, + "memory_combined_filtering": MEMORY_COMBINED_FILTERING_PROMPT, } MEMORY_ASSEMBLY_TEMPLATE = """The retrieved memories are listed as follows:\n\n {memory_text}""" diff --git a/tests/mem_scheduler/test_config.py b/tests/mem_scheduler/test_config.py new file mode 100644 index 00000000..b389220a --- /dev/null +++ b/tests/mem_scheduler/test_config.py @@ -0,0 +1,319 @@ +import os +import sys +import unittest + +from pathlib import Path +from tempfile import NamedTemporaryFile, TemporaryDirectory + +from memos.configs.mem_scheduler import AuthConfig, GraphDBAuthConfig, OpenAIConfig, RabbitMQConfig +from memos.mem_scheduler.general_modules.misc import EnvConfigMixin +from memos.mem_scheduler.utils.config_utils import convert_config_to_env, flatten_dict + + +FILE_PATH = Path(__file__).absolute() +BASE_DIR = FILE_PATH.parent.parent.parent +sys.path.insert(0, str(BASE_DIR)) + +ENV_PREFIX = EnvConfigMixin.ENV_PREFIX + + +class TestEnvConfigMixin(unittest.TestCase): + """Tests specifically for the EnvConfigMixin functionality""" + + def test_env_prefix_class_variable(self): + """Verify the base environment prefix is set correctly""" + 
self.assertEqual(EnvConfigMixin.ENV_PREFIX, "MEMSCHEDULER_") + + def test_get_env_prefix_generation(self): + """Test the dynamic environment variable prefix generation""" + # Test GraphDBAuthConfig specifically since it's causing issues + self.assertEqual( + GraphDBAuthConfig.get_env_prefix(), + f"{ENV_PREFIX}GRAPHDBAUTH_", # Critical: This is the correct prefix! + ) + + # Verify other configs + self.assertEqual(RabbitMQConfig.get_env_prefix(), f"{ENV_PREFIX}RABBITMQ_") + self.assertEqual(OpenAIConfig.get_env_prefix(), f"{ENV_PREFIX}OPENAI_") + + +class TestSchedulerConfig(unittest.TestCase): + def setUp(self): + self.env_backup = dict(os.environ) + self._clear_prefixed_env_vars() + + def tearDown(self): + os.environ.clear() + os.environ.update(self.env_backup) + + def _clear_prefixed_env_vars(self): + for key in list(os.environ.keys()): + if key.startswith(ENV_PREFIX): + del os.environ[key] + + def test_loads_all_configs_from_env(self): + """Test loading all configurations from prefixed environment variables""" + os.environ.update( + { + # RabbitMQ configs + f"{ENV_PREFIX}RABBITMQ_HOST_NAME": "rabbit.test.com", + f"{ENV_PREFIX}RABBITMQ_USER_NAME": "test_user", + f"{ENV_PREFIX}RABBITMQ_PASSWORD": "test_pass", + f"{ENV_PREFIX}RABBITMQ_VIRTUAL_HOST": "test_vhost", + f"{ENV_PREFIX}RABBITMQ_ERASE_ON_CONNECT": "false", + f"{ENV_PREFIX}RABBITMQ_PORT": "5673", + # OpenAI configs + f"{ENV_PREFIX}OPENAI_API_KEY": "test_api_key", + f"{ENV_PREFIX}OPENAI_BASE_URL": "https://api.test.openai.com", + f"{ENV_PREFIX}OPENAI_DEFAULT_MODEL": "gpt-test", + # GraphDBAuthConfig configs - NOTE THE CORRECT PREFIX! 
+ f"{ENV_PREFIX}GRAPHDBAUTH_URI": "bolt://test.db:7687", + f"{ENV_PREFIX}GRAPHDBAUTH_USER": "test_neo4j", + f"{ENV_PREFIX}GRAPHDBAUTH_PASSWORD": "test_db_pass_123", # 13 chars (valid) + f"{ENV_PREFIX}GRAPHDBAUTH_DB_NAME": "test_db", + f"{ENV_PREFIX}GRAPHDBAUTH_AUTO_CREATE": "false", + } + ) + + config = AuthConfig.from_local_env() + + # Verify GraphDB configuration + self.assertEqual(config.graph_db.uri, "bolt://test.db:7687") + self.assertEqual(config.graph_db.user, "test_neo4j") + self.assertEqual(config.graph_db.password, "test_db_pass_123") + self.assertEqual(config.graph_db.db_name, "test_db") + self.assertFalse(config.graph_db.auto_create) + + def test_uses_default_values_when_env_not_set(self): + """Test that default values are used when prefixed environment variables are not set""" + os.environ.update( + { + # RabbitMQ + f"{ENV_PREFIX}RABBITMQ_HOST_NAME": "rabbit.test.com", + # OpenAI + f"{ENV_PREFIX}OPENAI_API_KEY": "test_api_key", + # GraphDB - with correct prefix and valid password length + f"{ENV_PREFIX}GRAPHDBAUTH_URI": "bolt://test.db:7687", + f"{ENV_PREFIX}GRAPHDBAUTH_PASSWORD": "default_pass", # 11 chars (valid) + } + ) + + config = AuthConfig.from_local_env() + + # Verify default values take effect + self.assertEqual(config.rabbitmq.port, 5672) # RabbitMQ default port + self.assertTrue(config.graph_db.auto_create) # GraphDB default auto-create + + def test_raises_on_missing_required_variables(self): + """Test that exceptions are raised when required prefixed variables are missing""" + with self.assertRaises((ValueError, Exception)) as context: + AuthConfig.from_local_env() + + error_msg = str(context.exception).lower() + self.assertTrue( + "missing" in error_msg or "validation" in error_msg or "required" in error_msg, + f"Error message does not meet expectations: {error_msg}", + ) + + def test_type_conversion(self): + """Test type conversion for prefixed environment variables""" + os.environ.update( + { + # RabbitMQ + 
f"{ENV_PREFIX}RABBITMQ_HOST_NAME": "rabbit.test.com", + f"{ENV_PREFIX}RABBITMQ_PORT": "1234", + f"{ENV_PREFIX}RABBITMQ_ERASE_ON_CONNECT": "yes", + # OpenAI + f"{ENV_PREFIX}OPENAI_API_KEY": "test_api_key", + # GraphDB - correct prefix and valid password + f"{ENV_PREFIX}GRAPHDBAUTH_URI": "bolt://test.db:7687", + f"{ENV_PREFIX}GRAPHDBAUTH_PASSWORD": "type_conv_pass", # 13 chars (valid) + f"{ENV_PREFIX}GRAPHDBAUTH_AUTO_CREATE": "0", + } + ) + + config = AuthConfig.from_local_env() + + # Verify type conversion results + self.assertIsInstance(config.rabbitmq.port, int) + self.assertIsInstance(config.rabbitmq.erase_on_connect, bool) + self.assertIsInstance(config.graph_db.auto_create, bool) + self.assertTrue(config.rabbitmq.erase_on_connect) + self.assertFalse(config.graph_db.auto_create) + + def test_combined_with_local_config(self): + """Test priority between prefixed environment variables and config files""" + with NamedTemporaryFile(mode="w", delete=False, suffix=".yaml") as f: + f.write(""" + rabbitmq: + host_name: "file.rabbit.com" + port: 1234 + openai: + api_key: "file_api_key" + graph_db: + uri: "bolt://file.db:7687" + password: "file_db_pass" + """) + config_file_path = f.name + + try: + # Environment variables with correct prefixes + os.environ.update( + { + f"{ENV_PREFIX}RABBITMQ_HOST_NAME": "env.rabbit.com", + f"{ENV_PREFIX}OPENAI_API_KEY": "env_api_key", + f"{ENV_PREFIX}GRAPHDBAUTH_USER": "env_user", + f"{ENV_PREFIX}GRAPHDBAUTH_PASSWORD": "env_db_pass", # 11 chars (valid) + } + ) + + # 1. Test loading from config file + file_config = AuthConfig.from_local_config(Path(config_file_path)) + self.assertEqual(file_config.rabbitmq.host_name, "file.rabbit.com") + self.assertEqual(file_config.rabbitmq.port, 1234) + self.assertEqual(file_config.openai.api_key, "file_api_key") + self.assertEqual(file_config.graph_db.password, "file_db_pass") + + # 2. 
Test loading from environment variables + env_config = AuthConfig.from_local_env() + self.assertEqual(env_config.rabbitmq.host_name, "env.rabbit.com") + self.assertEqual(env_config.openai.api_key, "env_api_key") + self.assertEqual(env_config.graph_db.user, "env_user") + self.assertEqual(env_config.graph_db.password, "env_db_pass") + self.assertEqual(env_config.rabbitmq.port, 5672) + + finally: + os.unlink(config_file_path) + + +class TestConfigUtils(unittest.TestCase): + """Tests for config_utils functions: flatten_dict and convert_config_to_env""" + + def test_flatten_dict_basic(self): + """Test basic dictionary flattening without prefix""" + input_dict = {"database": {"host": "localhost", "port": 5432}, "auth": {"enabled": True}} + + expected = {"DATABASE_HOST": "localhost", "DATABASE_PORT": "5432", "AUTH_ENABLED": "True"} + + self.assertEqual(flatten_dict(input_dict), expected) + + def test_flatten_dict_with_prefix(self): + """Test dictionary flattening with a custom prefix""" + input_dict = {"rabbitmq": {"host": "rabbit.local"}} + + expected = {"APP_RABBITMQ_HOST": "rabbit.local"} + + self.assertEqual(flatten_dict(input_dict, prefix="app"), expected) + + def test_flatten_dict_special_chars(self): + """Test handling of spaces and hyphens in keys""" + input_dict = {"my key": "value", "other-key": {"nested key": 123}} + + expected = {"MY_KEY": "value", "OTHER_KEY_NESTED_KEY": "123"} + + self.assertEqual(flatten_dict(input_dict), expected) + + def test_flatten_dict_none_values(self): + """Test handling of None values""" + input_dict = {"optional": None, "required": "present"} + + expected = {"OPTIONAL": "", "REQUIRED": "present"} + + self.assertEqual(flatten_dict(input_dict), expected) + + def test_convert_json_to_env(self): + """Test conversion from JSON to .env file""" + with TemporaryDirectory() as temp_dir: + input_path = os.path.join(temp_dir, "config.json") + output_path = os.path.join(temp_dir, ".env") + + # Create test JSON file + with open(input_path, "w") 
as f: + f.write('{"server": {"port": 8080}, "debug": false}') + + # Convert to .env + convert_config_to_env(input_path, output_path, prefix="app") + + # Verify output + with open(output_path) as f: + content = f.read() + + self.assertIn('APP_SERVER_PORT="8080"', content) + self.assertIn('APP_DEBUG="False"', content) + + def test_convert_yaml_to_env(self): + """Test conversion from YAML to .env file""" + with TemporaryDirectory() as temp_dir: + input_path = os.path.join(temp_dir, "config.yaml") + output_path = os.path.join(temp_dir, ".env") + + # Create test YAML file + with open(input_path, "w") as f: + f.write(""" + database: + host: db.example.com + credentials: + user: admin + pass: secret + """) + + # Convert to .env + convert_config_to_env(input_path, output_path) + + # Verify output + with open(output_path) as f: + content = f.read() + + self.assertIn('DATABASE_HOST="db.example.com"', content) + self.assertIn('DATABASE_CREDENTIALS_USER="admin"', content) + self.assertIn('DATABASE_CREDENTIALS_PASS="secret"', content) + + def test_convert_with_special_values(self): + """Test conversion with values containing quotes and special characters""" + with TemporaryDirectory() as temp_dir: + input_path = os.path.join(temp_dir, "config.json") + output_path = os.path.join(temp_dir, ".env") + + # Create test JSON with special values + with open(input_path, "w") as f: + f.write('{"description": "Hello \\"World\\"", "empty": null}') + + # Convert to .env + convert_config_to_env(input_path, output_path) + + # Verify output + with open(output_path) as f: + content = f.read() + + # Values with double quotes should not have surrounding quotes + self.assertIn('DESCRIPTION=Hello "World"', content) + self.assertIn('EMPTY=""', content) + + def test_unsupported_file_format(self): + """Test error handling for unsupported file formats""" + with TemporaryDirectory() as temp_dir: + input_path = os.path.join(temp_dir, "config.txt") + with open(input_path, "w") as f: + f.write("some 
content") + + with self.assertRaises(ValueError) as context: + convert_config_to_env(input_path) + + self.assertIn("Unsupported file format", str(context.exception)) + + def test_file_not_found(self): + """Test error handling for non-existent input file""" + with self.assertRaises(FileNotFoundError): + convert_config_to_env("non_existent_file.json") + + def test_invalid_json(self): + """Test error handling for invalid JSON""" + with TemporaryDirectory() as temp_dir: + input_path = os.path.join(temp_dir, "bad.json") + with open(input_path, "w") as f: + f.write('{"invalid": json}') # Invalid JSON + + with self.assertRaises(ValueError) as context: + convert_config_to_env(input_path) + + self.assertIn("Error parsing file", str(context.exception)) diff --git a/tests/mem_scheduler/test_orm.py b/tests/mem_scheduler/test_orm.py new file mode 100644 index 00000000..ddf4fea8 --- /dev/null +++ b/tests/mem_scheduler/test_orm.py @@ -0,0 +1,299 @@ +import os +import tempfile +import time + +from datetime import datetime, timedelta + +import pytest + +from memos.mem_scheduler.orm_modules.base_model import BaseDBManager + +# Import the classes to test +from memos.mem_scheduler.orm_modules.monitor_models import ( + DBManagerForMemoryMonitorManager, + DBManagerForQueryMonitorQueue, +) +from memos.mem_scheduler.schemas.monitor_schemas import ( + MemoryMonitorItem, + MemoryMonitorManager, + QueryMonitorItem, + QueryMonitorQueue, +) + + +# Test data +TEST_USER_ID = "test_user" +TEST_MEM_CUBE_ID = "test_mem_cube" +TEST_QUEUE_ID = "test_queue" + + +class TestBaseDBManager: + """Base class for DBManager tests with common fixtures""" + + @pytest.fixture + def temp_db(self): + """Create a temporary database for testing.""" + temp_dir = tempfile.mkdtemp() + db_path = os.path.join(temp_dir, "test_scheduler_orm.db") + yield db_path + # Cleanup + try: + if os.path.exists(db_path): + os.remove(db_path) + os.rmdir(temp_dir) + except (OSError, PermissionError): + pass # Ignore cleanup errors 
(e.g., file locked on Windows) + + @pytest.fixture + def memory_manager_obj(self): + """Create a MemoryMonitorManager object for testing""" + return MemoryMonitorManager( + user_id=TEST_USER_ID, + mem_cube_id=TEST_MEM_CUBE_ID, + items=[ + MemoryMonitorItem( + item_id="custom-id-123", + memory_text="Full test memory", + tree_memory_item=None, + tree_memory_item_mapping_key="full_test_key", + keywords_score=0.8, + sorting_score=0.9, + importance_score=0.7, + recording_count=3, + ) + ], + ) + + @pytest.fixture + def query_queue_obj(self): + """Create a QueryMonitorQueue object for testing""" + queue = QueryMonitorQueue() + queue.put( + QueryMonitorItem( + item_id="query1", + user_id=TEST_USER_ID, + mem_cube_id=TEST_MEM_CUBE_ID, + query_text="How are you?", + timestamp=datetime.now(), + keywords=["how", "you"], + ) + ) + return queue + + @pytest.fixture + def query_monitor_manager(self, temp_db, query_queue_obj): + """Create DBManagerForQueryMonitorQueue instance with temp DB.""" + engine = BaseDBManager.create_engine_from_db_path(temp_db) + manager = DBManagerForQueryMonitorQueue( + engine=engine, + user_id=TEST_USER_ID, + mem_cube_id=TEST_MEM_CUBE_ID, + obj=query_queue_obj, + lock_timeout=10, + ) + + assert manager.engine is not None + assert manager.SessionLocal is not None + assert os.path.exists(temp_db) + + yield manager + manager.close() + + @pytest.fixture + def memory_monitor_manager(self, temp_db, memory_manager_obj): + """Create DBManagerForMemoryMonitorManager instance with temp DB.""" + engine = BaseDBManager.create_engine_from_db_path(temp_db) + manager = DBManagerForMemoryMonitorManager( + engine=engine, + user_id=TEST_USER_ID, + mem_cube_id=TEST_MEM_CUBE_ID, + obj=memory_manager_obj, + lock_timeout=10, + ) + + assert manager.engine is not None + assert manager.SessionLocal is not None + assert os.path.exists(temp_db) + + yield manager + manager.close() + + def test_save_and_load_query_queue(self, query_monitor_manager, query_queue_obj): + """Test saving 
and loading QueryMonitorQueue.""" + # Save to database + query_monitor_manager.save_to_db(query_queue_obj) + + # Load in a new manager + engine = BaseDBManager.create_engine_from_db_path(query_monitor_manager.engine.url.database) + new_manager = DBManagerForQueryMonitorQueue( + engine=engine, + user_id=TEST_USER_ID, + mem_cube_id=TEST_MEM_CUBE_ID, + obj=None, + lock_timeout=10, + ) + loaded_queue = new_manager.load_from_db(acquire_lock=True) + + assert loaded_queue is not None + items = loaded_queue.get_queue_content_without_pop() + assert len(items) == 1 + assert items[0].item_id == "query1" + assert items[0].query_text == "How are you?" + new_manager.close() + + def test_lock_mechanism(self, query_monitor_manager, query_queue_obj): + """Test lock acquisition and release.""" + # Save current state + query_monitor_manager.save_to_db(query_queue_obj) + + # Acquire lock + acquired = query_monitor_manager.acquire_lock(block=True) + assert acquired + + # Try to acquire again (should fail without blocking) + assert not query_monitor_manager.acquire_lock(block=False) + + # Release lock + query_monitor_manager.release_locks( + user_id=TEST_USER_ID, + mem_cube_id=TEST_MEM_CUBE_ID, + ) + + # Should be able to acquire again + assert query_monitor_manager.acquire_lock(block=False) + + def test_lock_timeout(self, query_monitor_manager, query_queue_obj): + """Test lock timeout mechanism.""" + # Save current state + query_monitor_manager.save_to_db(query_queue_obj) + + query_monitor_manager.lock_timeout = 1 + + # Acquire lock + assert query_monitor_manager.acquire_lock(block=True) + + # Wait for lock to expire + time.sleep(1.1) + + # Should be able to acquire again + assert query_monitor_manager.acquire_lock(block=False) + + def test_sync_with_orm(self, query_monitor_manager, query_queue_obj): + """Test synchronization between ORM and object.""" + query_queue_obj.put( + QueryMonitorItem( + item_id="query2", + user_id=TEST_USER_ID, + mem_cube_id=TEST_MEM_CUBE_ID, + 
query_text="What's your name?", + timestamp=datetime.now(), + keywords=["name"], + ) + ) + + # Save current state + query_monitor_manager.save_to_db(query_queue_obj) + + # Create sync manager with empty queue + empty_queue = QueryMonitorQueue(maxsize=10) + engine = BaseDBManager.create_engine_from_db_path(query_monitor_manager.engine.url.database) + sync_manager = DBManagerForQueryMonitorQueue( + engine=engine, + user_id=TEST_USER_ID, + mem_cube_id=TEST_MEM_CUBE_ID, + obj=empty_queue, + lock_timeout=10, + ) + + # First sync - should create a new record with empty queue + sync_manager.sync_with_orm(size_limit=None) + items = sync_manager.obj.get_queue_content_without_pop() + assert len(items) == 0 # Empty queue since no existing data to merge + + # Now save the empty queue to create a record + sync_manager.save_to_db(empty_queue) + + # Test that sync_with_orm correctly handles version control + # The sync should increment version but not merge data when versions are the same + sync_manager.sync_with_orm(size_limit=None) + items = sync_manager.obj.get_queue_content_without_pop() + assert len(items) == 0 # Should remain empty since no merge occurred + + # Verify that the version was incremented + assert sync_manager.last_version_control == "3" # Should increment from 2 to 3 + + sync_manager.close() + + def test_sync_with_size_limit(self, query_monitor_manager, query_queue_obj): + """Test synchronization with size limit.""" + now = datetime.now() + item_size = 1 + for i in range(2, 6): + item_size += 1 + query_queue_obj.put( + QueryMonitorItem( + item_id=f"query{i}", + user_id=TEST_USER_ID, + mem_cube_id=TEST_MEM_CUBE_ID, + query_text=f"Question {i}", + timestamp=now + timedelta(minutes=i), + keywords=[f"kw{i}"], + ) + ) + + # First sync - should create a new record (size_limit not applied for new records) + size_limit = 3 + query_monitor_manager.sync_with_orm(size_limit=size_limit) + items = query_monitor_manager.obj.get_queue_content_without_pop() + assert len(items) 
== item_size # All items since size_limit not applied for new records + + # Save to create the record + query_monitor_manager.save_to_db(query_monitor_manager.obj) + + # Test that sync_with_orm correctly handles version control + # The sync should increment version but not merge data when versions are the same + query_monitor_manager.sync_with_orm(size_limit=size_limit) + items = query_monitor_manager.obj.get_queue_content_without_pop() + assert len(items) == item_size # Should remain the same since no merge occurred + + # Verify that the version was incremented + assert query_monitor_manager.last_version_control == "2" + + def test_concurrent_access(self, temp_db, query_queue_obj): + """Test concurrent access to the same database.""" + + # Manager 1 + engine1 = BaseDBManager.create_engine_from_db_path(temp_db) + manager1 = DBManagerForQueryMonitorQueue( + engine=engine1, + user_id=TEST_USER_ID, + mem_cube_id=TEST_MEM_CUBE_ID, + obj=query_queue_obj, + lock_timeout=10, + ) + manager1.save_to_db(query_queue_obj) + + # Manager 2 + engine2 = BaseDBManager.create_engine_from_db_path(temp_db) + manager2 = DBManagerForQueryMonitorQueue( + engine=engine2, + user_id=TEST_USER_ID, + mem_cube_id=TEST_MEM_CUBE_ID, + obj=query_queue_obj, + lock_timeout=10, + ) + + # Manager1 acquires lock + assert manager1.acquire_lock(block=True) + + # Manager2 fails to acquire + assert not manager2.acquire_lock(block=False) + + # Manager1 releases + manager1.release_locks(user_id=TEST_USER_ID, mem_cube_id=TEST_MEM_CUBE_ID) + + # Manager2 can now acquire + assert manager2.acquire_lock(block=False) + + manager1.close() + manager2.close() diff --git a/tests/mem_scheduler/test_retriever.py b/tests/mem_scheduler/test_retriever.py index 0ef6eb8e..35c8b7f3 100644 --- a/tests/mem_scheduler/test_retriever.py +++ b/tests/mem_scheduler/test_retriever.py @@ -1,18 +1,25 @@ +import json import sys import unittest from pathlib import Path from unittest.mock import MagicMock, patch -from 
memos.configs.mem_scheduler import SchedulerConfigFactory +from memos.configs.mem_scheduler import ( + AuthConfig, + GraphDBAuthConfig, + OpenAIConfig, + RabbitMQConfig, + SchedulerConfigFactory, +) from memos.llms.base import BaseLLM from memos.mem_cube.general import GeneralMemCube from memos.mem_scheduler.scheduler_factory import SchedulerFactory from memos.mem_scheduler.utils.filter_utils import ( - filter_similar_memories, filter_too_short_memories, + filter_vector_based_similar_memories, ) -from memos.memories.textual.tree import TreeTextMemory +from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory FILE_PATH = Path(__file__).absolute() @@ -21,6 +28,25 @@ class TestSchedulerRetriever(unittest.TestCase): + def _create_mock_auth_config(self): + """Create a mock AuthConfig for testing purposes.""" + # Create mock configs with valid test values + graph_db_config = GraphDBAuthConfig( + uri="bolt://localhost:7687", + user="neo4j", + password="test_password_123", # 8+ characters to pass validation + db_name="neo4j", + auto_create=True, + ) + + rabbitmq_config = RabbitMQConfig( + host_name="localhost", port=5672, user_name="guest", password="guest", virtual_host="/" + ) + + openai_config = OpenAIConfig(api_key="test_api_key_123", default_model="gpt-3.5-turbo") + + return AuthConfig(rabbitmq=rabbitmq_config, openai=openai_config, graph_db=graph_db_config) + def setUp(self): """Initialize test environment with mock objects.""" example_scheduler_config_path = ( @@ -37,6 +63,13 @@ def setUp(self): self.mem_cube.text_mem = self.tree_text_memory self.mem_cube.act_mem = MagicMock() + # Mock AuthConfig.from_local_env() to return our test config + mock_auth_config = self._create_mock_auth_config() + self.auth_config_patch = patch( + "memos.configs.mem_scheduler.AuthConfig.from_local_env", return_value=mock_auth_config + ) + self.auth_config_patch.start() + # Initialize general_modules with mock LLM self.scheduler.initialize_modules(chat_llm=self.llm, 
process_llm=self.llm) self.scheduler.mem_cube = self.mem_cube @@ -47,17 +80,21 @@ def setUp(self): self.logging_warning_patch = patch("logging.warning") self.mock_logging_warning = self.logging_warning_patch.start() - self.logger_info_patch = patch("memos.mem_scheduler.general_modules.retriever.logger.info") + # Mock the MemoryFilter logger since that's where the actual logging happens + self.logger_info_patch = patch( + "memos.mem_scheduler.memory_manage_modules.memory_filter.logger.info" + ) self.mock_logger_info = self.logger_info_patch.start() def tearDown(self): """Clean up patches.""" self.logging_warning_patch.stop() self.logger_info_patch.stop() + self.auth_config_patch.stop() def test_filter_similar_memories_empty_input(self): """Test filter_similar_memories with empty input list.""" - result = filter_similar_memories([]) + result = filter_vector_based_similar_memories([]) self.assertEqual(result, []) def test_filter_similar_memories_no_duplicates(self): @@ -68,7 +105,7 @@ def test_filter_similar_memories_no_duplicates(self): "And this third one has nothing in common with the others", ] - result = filter_similar_memories(memories) + result = filter_vector_based_similar_memories(memories) self.assertEqual(len(result), 3) self.assertEqual(set(result), set(memories)) @@ -79,14 +116,14 @@ def test_filter_similar_memories_with_duplicates(self): "The user is planning to move to Chicago next month, which reflects a significant change in their living situation.", "The user is planning to move to Chicago in the upcoming month, indicating a significant change in their living situation.", ] - result = filter_similar_memories(memories, similarity_threshold=0.75) + result = filter_vector_based_similar_memories(memories, similarity_threshold=0.75) self.assertLess(len(result), len(memories)) def test_filter_similar_memories_error_handling(self): """Test filter_similar_memories error handling.""" # Test with non-string input (should return original list due to error) 
memories = ["valid text", 12345, "another valid text"] - result = filter_similar_memories(memories) + result = filter_vector_based_similar_memories(memories) self.assertEqual(result, memories) def test_filter_too_short_memories_empty_input(self): @@ -134,3 +171,192 @@ def test_filter_too_short_memories_edge_case(self): ) # "Exactly three words here", "Two words only", "Four words right here" self.assertIn("Exactly three words here", result) self.assertIn("Four words right here", result) + + def test_filter_unrelated_memories_empty_memories(self): + """Test filter_unrelated_memories with empty memories list.""" + query_history = ["What is the weather like?", "Tell me about Python programming"] + + result, success_flag = self.retriever.filter_unrelated_memories( + query_history=query_history, memories=[] + ) + + self.assertEqual(result, []) + self.assertTrue(success_flag) + self.mock_logger_info.assert_called_with("No memories to filter - returning empty list") + + def test_filter_unrelated_memories_empty_query_history(self): + """Test filter_unrelated_memories with empty query history.""" + memories = [ + TextualMemoryItem(memory="Python is a programming language"), + TextualMemoryItem(memory="Machine learning uses algorithms"), + TextualMemoryItem(memory="Data science involves statistics"), + ] + + result, success_flag = self.retriever.filter_unrelated_memories( + query_history=[], memories=memories + ) + + self.assertEqual(result, memories) + self.assertTrue(success_flag) + self.mock_logger_info.assert_called_with("No query history provided - keeping all memories") + + def test_filter_unrelated_memories_successful_filtering(self): + """Test filter_unrelated_memories with successful LLM filtering.""" + query_history = ["What is Python?", "How does machine learning work?"] + memories = [ + TextualMemoryItem(memory="Python is a high-level programming language"), + TextualMemoryItem(memory="Machine learning algorithms learn from data"), + TextualMemoryItem(memory="The 
weather is sunny today"), # Unrelated + TextualMemoryItem(memory="Python has many libraries for ML"), + TextualMemoryItem(memory="Cooking recipes for pasta"), # Unrelated + ] + + # Mock LLM response for successful filtering + mock_llm_response = { + "relevant_memories": [0, 1, 3], # Keep Python, ML, and Python ML libraries + "filtered_count": 2, # Filter out weather and cooking + "reasoning": "Kept memories related to Python and machine learning, filtered out unrelated topics", + } + + # Convert to proper JSON string + self.llm.generate.return_value = json.dumps(mock_llm_response) + + result, success_flag = self.retriever.filter_unrelated_memories( + query_history=query_history, memories=memories + ) + + # Verify results + self.assertEqual(len(result), 3) + self.assertIn(memories[0], result) # Python + self.assertIn(memories[1], result) # ML + self.assertIn(memories[3], result) # Python ML libraries + self.assertNotIn(memories[2], result) # Weather + self.assertNotIn(memories[4], result) # Cooking + self.assertTrue(success_flag) + + # Verify LLM was called correctly + self.llm.generate.assert_called_once() + call_args = self.llm.generate.call_args[0][0] + self.assertEqual(call_args[0]["role"], "user") + self.assertIn("Memory Relevance Filtering Task", call_args[0]["content"]) + + def test_filter_unrelated_memories_llm_failure_fallback(self): + """Test filter_unrelated_memories with LLM failure - should fallback to keeping all memories.""" + query_history = ["What is Python?"] + memories = [ + TextualMemoryItem(memory="Python is a programming language"), + TextualMemoryItem(memory="Machine learning is a subset of AI"), + ] + + # Mock LLM to return an invalid response that will trigger error handling + self.llm.generate.return_value = "Invalid response that cannot be parsed" + + result, success_flag = self.retriever.filter_unrelated_memories( + query_history=query_history, memories=memories + ) + + # Should return all memories as fallback + self.assertEqual(result, 
memories) + self.assertFalse(success_flag) + + # Verify error was logged + self.mock_logger_info.assert_called_with( + "Starting memory filtering for 2 memories against 1 queries" + ) + + def test_filter_unrelated_memories_invalid_json_response(self): + """Test filter_unrelated_memories with invalid JSON response from LLM.""" + query_history = ["What is Python?"] + memories = [ + TextualMemoryItem(memory="Python is a programming language"), + TextualMemoryItem(memory="Machine learning is a subset of AI"), + ] + + # Mock LLM to return invalid JSON + self.llm.generate.return_value = "This is not valid JSON" + + result, success_flag = self.retriever.filter_unrelated_memories( + query_history=query_history, memories=memories + ) + + # Should return all memories as fallback + self.assertEqual(result, memories) + self.assertFalse(success_flag) + + def test_filter_unrelated_memories_invalid_indices(self): + """Test filter_unrelated_memories with invalid indices in LLM response.""" + query_history = ["What is Python?"] + memories = [ + TextualMemoryItem(memory="Python is a programming language"), + TextualMemoryItem(memory="Machine learning is a subset of AI"), + ] + + # Mock LLM to return invalid indices + mock_llm_response = { + "relevant_memories": [0, 5, -1], # Invalid indices + "filtered_count": 1, + "reasoning": "Some memories are relevant", + } + + # Convert to proper JSON string + self.llm.generate.return_value = json.dumps(mock_llm_response) + + result, success_flag = self.retriever.filter_unrelated_memories( + query_history=query_history, memories=memories + ) + + # Should only include valid indices + self.assertEqual(len(result), 1) + self.assertIn(memories[0], result) # Index 0 is valid + self.assertTrue(success_flag) + + def test_filter_unrelated_memories_missing_required_fields(self): + """Test filter_unrelated_memories with missing required fields in LLM response.""" + query_history = ["What is Python?"] + memories = [ + TextualMemoryItem(memory="Python is a 
programming language"), + TextualMemoryItem(memory="Machine learning is a subset of AI"), + ] + + # Mock LLM to return response missing required fields + mock_llm_response = { + "relevant_memories": [0, 1] + # Missing "filtered_count" and "reasoning" + } + + # Convert to proper JSON string + self.llm.generate.return_value = json.dumps(mock_llm_response) + + result, success_flag = self.retriever.filter_unrelated_memories( + query_history=query_history, memories=memories + ) + + # Should return all memories as fallback due to missing fields + self.assertEqual(result, memories) + self.assertFalse(success_flag) + + def test_filter_unrelated_memories_conservative_filtering(self): + """Test that filter_unrelated_memories uses conservative approach - keeps memories when in doubt.""" + query_history = ["What is Python?"] + memories = [ + TextualMemoryItem(memory="Python is a programming language"), + TextualMemoryItem(memory="Machine learning is a subset of AI"), + TextualMemoryItem(memory="The weather is sunny today"), # Potentially unrelated + ] + + # Mock LLM to return all memories as relevant (conservative) + mock_llm_response = { + "relevant_memories": [0, 1, 2], # Keep all memories + "filtered_count": 0, # No filtering + "reasoning": "All memories could potentially provide context", + } + + self.llm.generate.return_value = json.dumps(mock_llm_response) + + result, success_flag = self.retriever.filter_unrelated_memories( + query_history=query_history, memories=memories + ) + + # Should return all memories + self.assertEqual(result, memories) + self.assertTrue(success_flag) diff --git a/tests/mem_scheduler/test_scheduler.py b/tests/mem_scheduler/test_scheduler.py index 97377738..51ea5677 100644 --- a/tests/mem_scheduler/test_scheduler.py +++ b/tests/mem_scheduler/test_scheduler.py @@ -3,12 +3,18 @@ from datetime import datetime from pathlib import Path -from unittest.mock import MagicMock - -from memos.configs.mem_scheduler import SchedulerConfigFactory +from 
unittest.mock import MagicMock, patch + +from memos.configs.mem_scheduler import ( + AuthConfig, + GraphDBAuthConfig, + OpenAIConfig, + RabbitMQConfig, + SchedulerConfigFactory, +) from memos.llms.base import BaseLLM from memos.mem_cube.general import GeneralMemCube -from memos.mem_scheduler.general_modules.retriever import SchedulerRetriever +from memos.mem_scheduler.memory_manage_modules.retriever import SchedulerRetriever from memos.mem_scheduler.monitors.general_monitor import SchedulerGeneralMonitor from memos.mem_scheduler.scheduler_factory import SchedulerFactory from memos.mem_scheduler.schemas.general_schemas import ( @@ -27,6 +33,25 @@ class TestGeneralScheduler(unittest.TestCase): + def _create_mock_auth_config(self): + """Create a mock AuthConfig for testing purposes.""" + # Create mock configs with valid test values + graph_db_config = GraphDBAuthConfig( + uri="bolt://localhost:7687", + user="neo4j", + password="test_password_123", # 8+ characters to pass validation + db_name="neo4j", + auto_create=True, + ) + + rabbitmq_config = RabbitMQConfig( + host_name="localhost", port=5672, user_name="guest", password="guest", virtual_host="/" + ) + + openai_config = OpenAIConfig(api_key="test_api_key_123", default_model="gpt-3.5-turbo") + + return AuthConfig(rabbitmq=rabbitmq_config, openai=openai_config, graph_db=graph_db_config) + def setUp(self): """Initialize test environment with mock objects and test scheduler instance.""" example_scheduler_config_path = ( @@ -43,6 +68,13 @@ def setUp(self): self.mem_cube.text_mem = self.tree_text_memory self.mem_cube.act_mem = MagicMock() + # Mock AuthConfig.from_local_env() to return our test config + mock_auth_config = self._create_mock_auth_config() + self.auth_config_patch = patch( + "memos.configs.mem_scheduler.AuthConfig.from_local_env", return_value=mock_auth_config + ) + self.auth_config_patch.start() + # Initialize general_modules with mock LLM self.scheduler.initialize_modules(chat_llm=self.llm, 
process_llm=self.llm) self.scheduler.mem_cube = self.mem_cube @@ -51,6 +83,10 @@ def setUp(self): self.scheduler.current_user_id = "test_user" self.scheduler.current_mem_cube_id = "test_cube" + def tearDown(self): + """Clean up patches.""" + self.auth_config_patch.stop() + def test_initialization(self): """Test that scheduler initializes with correct default values and handlers.""" # Verify handler registration diff --git a/tests/mem_scheduler/test_version_control.py b/tests/mem_scheduler/test_version_control.py new file mode 100644 index 00000000..efe2c6b7 --- /dev/null +++ b/tests/mem_scheduler/test_version_control.py @@ -0,0 +1,273 @@ +import os +import tempfile + +import pytest + +from memos.mem_scheduler.orm_modules.base_model import BaseDBManager +from memos.mem_scheduler.orm_modules.monitor_models import DBManagerForMemoryMonitorManager +from memos.mem_scheduler.schemas.monitor_schemas import ( + MemoryMonitorItem, + MemoryMonitorManager, +) + + +class TestVersionControl: + """Test version control functionality""" + + @pytest.fixture + def temp_db(self): + """Create a temporary database for testing.""" + temp_dir = tempfile.mkdtemp() + db_path = os.path.join(temp_dir, "test_version_control.db") + yield db_path + # Cleanup + try: + if os.path.exists(db_path): + os.remove(db_path) + os.rmdir(temp_dir) + except (OSError, PermissionError): + pass + + @pytest.fixture + def memory_manager_obj(self): + """Create a MemoryMonitorManager object for testing""" + return MemoryMonitorManager( + user_id="test_user", + mem_cube_id="test_mem_cube", + memories=[ + MemoryMonitorItem( + item_id="test-item-1", + memory_text="Test memory 1", + tree_memory_item=None, + tree_memory_item_mapping_key="test_key_1", + keywords_score=0.8, + sorting_score=0.9, + importance_score=0.7, + recording_count=1, + ) + ], + ) + + def test_version_control_increment(self, temp_db, memory_manager_obj): + """Test that version_control increments correctly""" + engine = 
BaseDBManager.create_engine_from_db_path(temp_db) + manager = DBManagerForMemoryMonitorManager( + engine=engine, + user_id="test_user", + mem_cube_id="test_mem_cube", + obj=memory_manager_obj, + ) + + try: + # Test increment method + assert manager._increment_version_control("0") == "1" + assert manager._increment_version_control("255") == "0" # Should cycle back to 0 + assert manager._increment_version_control("100") == "101" + assert ( + manager._increment_version_control("invalid") == "0" + ) # Should handle invalid input + + finally: + manager.close() + + def test_new_record_has_version_zero(self, temp_db, memory_manager_obj): + """Test that new records start with version_control = "0" """ + engine = BaseDBManager.create_engine_from_db_path(temp_db) + manager = DBManagerForMemoryMonitorManager( + engine=engine, + user_id="test_user", + mem_cube_id="test_mem_cube", + obj=memory_manager_obj, + ) + + try: + # Save to database + manager.save_to_db(memory_manager_obj) + + # Check that last_version_control was set to "0" + assert manager.last_version_control == "0" + + # Load from database and verify version_control + loaded_obj = manager.load_from_db() + assert loaded_obj is not None + + # Check that the version was tracked + assert manager.last_version_control == "0" + + finally: + manager.close() + + def test_version_control_increments_on_save(self, temp_db, memory_manager_obj): + """Test that version_control increments when saving existing records""" + engine = BaseDBManager.create_engine_from_db_path(temp_db) + manager = DBManagerForMemoryMonitorManager( + engine=engine, + user_id="test_user", + mem_cube_id="test_mem_cube", + obj=memory_manager_obj, + ) + + try: + # First save - should create with version "0" + manager.save_to_db(memory_manager_obj) + assert manager.last_version_control == "0" + + # Second save - should increment to version "1" + manager.save_to_db(memory_manager_obj) + assert manager.last_version_control == "1" + + # Third save - should 
increment to version "2" + manager.save_to_db(memory_manager_obj) + assert manager.last_version_control == "2" + + finally: + manager.close() + + def test_sync_with_orm_version_control(self, temp_db, memory_manager_obj): + """Test version control behavior in sync_with_orm""" + engine = BaseDBManager.create_engine_from_db_path(temp_db) + manager = DBManagerForMemoryMonitorManager( + engine=engine, + user_id="test_user", + mem_cube_id="test_mem_cube", + obj=memory_manager_obj, + ) + + try: + # First sync - should create with version "0" + manager.sync_with_orm() + assert manager.last_version_control == "0" + + # Second sync with same object - should increment version because sync_with_orm always increments + manager.sync_with_orm() + assert ( + manager.last_version_control == "1" + ) # Should increment to "1" since sync_with_orm always increments + + # Third sync - should increment to version "2" + manager.sync_with_orm() + assert manager.last_version_control == "2" # Should increment to "2" + + # Simulate a change by creating a new object with different content + new_memory_manager = MemoryMonitorManager( + user_id="test_user", + mem_cube_id="test_mem_cube", + memories=[ + MemoryMonitorItem( + item_id="test-item-2", + memory_text="Test memory 2", + tree_memory_item=None, + tree_memory_item_mapping_key="test_key_2", + keywords_score=0.9, + sorting_score=0.8, + importance_score=0.6, + recording_count=2, + ) + ], + ) + + # Update the manager's object + manager.obj = new_memory_manager + + # Sync again - should increment version because object content changed + manager.sync_with_orm() + assert manager.last_version_control == "3" # Should increment to "3" + + finally: + manager.close() + + def test_version_control_cycles_correctly(self, temp_db, memory_manager_obj): + """Test that version_control cycles from 255 back to 0""" + engine = BaseDBManager.create_engine_from_db_path(temp_db) + manager = DBManagerForMemoryMonitorManager( + engine=engine, + user_id="test_user", + 
mem_cube_id="test_mem_cube", + obj=memory_manager_obj, + ) + + try: + # Test the increment method directly + assert manager._increment_version_control("255") == "0" + assert manager._increment_version_control("254") == "255" + assert manager._increment_version_control("0") == "1" + + finally: + manager.close() + + def test_load_from_db_updates_version_control(self, temp_db, memory_manager_obj): + """Test that load_from_db updates last_version_control correctly""" + engine = BaseDBManager.create_engine_from_db_path(temp_db) + manager = DBManagerForMemoryMonitorManager( + engine=engine, + user_id="test_user", + mem_cube_id="test_mem_cube", + obj=memory_manager_obj, + ) + + try: + # Save to database first + manager.save_to_db(memory_manager_obj) + assert manager.last_version_control == "0" + + # Create a new manager instance to load the data + load_manager = DBManagerForMemoryMonitorManager( + engine=engine, + user_id="test_user", + mem_cube_id="test_mem_cube", + ) + + # Load from database + loaded_obj = load_manager.load_from_db() + assert loaded_obj is not None + assert load_manager.last_version_control == "0" # Should be updated to loaded version + + load_manager.close() + + finally: + manager.close() + + def test_version_control_persistence_across_instances(self, temp_db, memory_manager_obj): + """Test that version control persists across different manager instances""" + engine = BaseDBManager.create_engine_from_db_path(temp_db) + + # First manager instance + manager1 = DBManagerForMemoryMonitorManager( + engine=engine, + user_id="test_user", + mem_cube_id="test_mem_cube", + obj=memory_manager_obj, + ) + + try: + # Save multiple times to increment version + manager1.save_to_db(memory_manager_obj) + assert manager1.last_version_control == "0" + + manager1.save_to_db(memory_manager_obj) + assert manager1.last_version_control == "1" + + manager1.save_to_db(memory_manager_obj) + assert manager1.last_version_control == "2" + + # Create second manager instance + 
manager2 = DBManagerForMemoryMonitorManager( + engine=engine, + user_id="test_user", + mem_cube_id="test_mem_cube", + obj=memory_manager_obj, + ) + + # Load should show the same version + loaded_obj = manager2.load_from_db() + assert loaded_obj is not None + assert manager2.last_version_control == "2" # Should match the last saved version + + # Save again should increment from the loaded version + manager2.save_to_db(memory_manager_obj) + assert manager2.last_version_control == "3" + + manager2.close() + + finally: + manager1.close()