From fc722d85eef6644f7593dd26c7fd55a56615595b Mon Sep 17 00:00:00 2001 From: David Miguel Lozano Date: Mon, 10 Feb 2025 17:59:05 +0100 Subject: [PATCH] fix: RunnableMap doesn't invoke multiple Runnables in parallel (#649) --- .../langchain_core/lib/src/runnables/map.dart | 10 +++---- .../test/runnables/map_test.dart | 28 +++++++++++++++++++ 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/packages/langchain_core/lib/src/runnables/map.dart b/packages/langchain_core/lib/src/runnables/map.dart index 0b3cb925..4634b2c9 100644 --- a/packages/langchain_core/lib/src/runnables/map.dart +++ b/packages/langchain_core/lib/src/runnables/map.dart @@ -62,16 +62,16 @@ class RunnableMap final RunInput input, { final RunnableOptions? options, }) async { - final output = {}; - - await Future.forEach(steps.entries, (final entry) async { - output[entry.key] = await entry.value.invoke( + final futures = steps.entries.map((entry) async { + final result = await entry.value.invoke( input, options: entry.value.getCompatibleOptions(options), ); + return MapEntry(entry.key, result); }); - return output; + final results = await Future.wait(futures); + return Map.fromEntries(results); } @override diff --git a/packages/langchain_core/test/runnables/map_test.dart b/packages/langchain_core/test/runnables/map_test.dart index e65dc73a..351fa7fc 100644 --- a/packages/langchain_core/test/runnables/map_test.dart +++ b/packages/langchain_core/test/runnables/map_test.dart @@ -25,6 +25,34 @@ void main() { ); }); + test('RunnableMap runs tasks in parallel', () async { + final longTask = Runnable.fromFunction( + invoke: (_, __) async { + await Future.delayed(const Duration(seconds: 2)); + return 'long'; + }, + ); + final shortTask = Runnable.fromFunction( + invoke: (_, __) async { + await Future.delayed(const Duration(seconds: 1)); + return 'short'; + }, + ); + + final chain = Runnable.fromMap({ + 'long': longTask, + 'short': shortTask, + }); + + final stopwatch = Stopwatch()..start(); + final result = await chain.invoke({}); + stopwatch.stop(); + + expect(stopwatch.elapsed, lessThan(const Duration(seconds: 3))); + expect(result['long'], 'long'); + expect(result['short'], 'short'); + }); + test('Streaming RunnableMap', () async { final prompt1 = PromptTemplate.fromTemplate('Hello {input}!'); final prompt2 = PromptTemplate.fromTemplate('Bye {input}!');