Multi-Domain RL Training #105
base: main
Changes from 65 commits
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| defaults: | ||
| - base | ||
| - _self_ | ||
| | ||
| actor: | ||
| rollout_policy: pipelinerl.domains.coding.generate_coding_rollout | ||
| system_prompt: "" | ||
| task_template: |- | ||
| {task} | ||
| task_prompt: "" | ||
| ensure_boxed_answers: false | ||
| | ||
| coding_time_limit_s: 15.0 | ||
| coding_per_test_timeout_s: 10.0 | ||
| coding_memory_limit_bytes: 1073741824 | ||
| coding_compile_timeout_s: 10.0 | ||
| coding_sandbox_url: ${oc.env:CODING_SANDBOX_URL, "http://sandbox:8080/run_code"} | ||
| | ||
| dataset_loader: pipelinerl.domains.coding.dataset.load_problems | ||
| dataset_loader_params: | ||
| dataset_id: ServiceNow-AI/mixed-training-text-datasets | ||
| dataset_config: 80k-if-math-coding-fncalling-stem | ||
| split_ratios: | ||
| train: 0.9 | ||
| validation: 0.05 | ||
| test: 0.05 | ||
| allowed_call_types: | ||
| - assert | ||
| - std | ||
| max_examples_per_split: 2048 | ||
| trust_remote_code: true | ||
| huggingface_token: ${oc.env:CODING_HF_TOKEN, null} | ||
| | ||
| train_dataset_names: | ||
| - coding@train | ||
| | ||
| test_dataset_names: | ||
| - coding@validation | ||
| | ||
| environments: | ||
| - key: coding | ||
| mode: remote | ||
| _target_: pipelinerl.domains.coding.CodingSandboxEnvironment | ||
| sandbox_url: ${actor.coding_sandbox_url} | ||
| compile_timeout_s: ${actor.coding_compile_timeout_s} | ||
| run_timeout_s: ${actor.coding_per_test_timeout_s} | ||
| request_timeout_s: ${actor.coding_time_limit_s} | ||
| memory_limit_bytes: ${actor.coding_memory_limit_bytes} | ||
| | ||
| environment_key: coding |
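
The `split_ratios` and `max_examples_per_split` keys above suggest that the loader partitions one dataset into train/validation/test splits and caps the size of each. The diff does not show how `load_problems` actually does this, so the following is only a minimal sketch under that assumption; `split_bounds` is a hypothetical helper, not part of pipelinerl.

```python
def split_bounds(n_examples, split_ratios, max_examples_per_split=None):
    """Return {split: (start, end)} index ranges over n_examples items.

    Hypothetical helper: contiguous ranges proportional to split_ratios,
    each optionally truncated to max_examples_per_split items.
    """
    total = sum(split_ratios.values())
    if total <= 0:
        raise ValueError("split_ratios must sum to a positive number")
    bounds, start, cumulative = {}, 0, 0.0
    for name, ratio in split_ratios.items():
        cumulative += ratio
        # Round the cumulative boundary so the uncapped ranges tile [0, n_examples).
        end = min(n_examples, round(n_examples * cumulative / total))
        capped_end = end
        if max_examples_per_split is not None:
            capped_end = min(end, start + max_examples_per_split)
        bounds[name] = (start, capped_end)
        start = end
    return bounds


# Values mirror the config above; the dataset size of 10,000 is made up.
bounds = split_bounds(
    10_000,
    {"train": 0.9, "validation": 0.05, "test": 0.05},
    max_examples_per_split=2048,
)
```

With these numbers the cap only bites on the train split; the validation and test ranges stay at 500 examples each.
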
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| defaults: | ||
| - base | ||
| - domain_rollouts: base | ||
| - override rewards: success_and_format | ||
| - _self_ | ||
| | ||
| actor: | ||
| rollout_policy: pipelinerl.domains.dispatcher.generate_multidomain_rollout | ||
| llm_max_rollouts: 2 | ||
| rollout_workers: 1 | ||
| domain_rollouts: | ||
| math: ${domain_rollouts.math} | ||
| guessing: ${domain_rollouts.guessing} | ||
| coding: ${domain_rollouts.coding} | ||
| | ||
| dataset_loader: pipelinerl.domains.multidomain.load_problems | ||
| train_dataset_names: | ||
| - math_debug | ||
| - guessing_debug | ||
| - coding_debug | ||
| test_dataset_names: | ||
| - math_debug | ||
| - coding_debug | ||
| | ||
| environment: null | ||
| environment_key: null | ||
| | ||
| world: | ||
| env_replicas_per_actor: 0 | ||
| environment_mode: embedded |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| # Domain mix presets | ||
| | ||
| Hydra group `domain_mix` stores reusable presets for `actor.domain_mix`. | ||
| | ||
| Usage examples: | ||
| | ||
| ``` | ||
| python main.py --config-name multi_domain/base +domain_mix=math_coding_70_30 | ||
| python main.py --config-name multi_domain/base +domain_mix=balanced | ||
| ``` | ||
| | ||
| Override or extend these presets by creating new files under `conf/domain_mix/`. |
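
The presets in this PR mix values that sum to 1 (for example `math: 0.7`, `coding: 0.3`) with a preset that sets every domain to `1.0`, which suggests the entries are relative weights rather than probabilities. Assuming that reading, the sketch below shows one way such a mapping could be normalized and used to pick a domain per rollout. `normalize_mix` and `sample_domains` are hypothetical names; the dispatcher's real sampling logic is not shown in this diff.

```python
import random


def normalize_mix(domain_mix):
    """Turn relative weights into probabilities that sum to 1."""
    if any(w < 0 for w in domain_mix.values()):
        raise ValueError("domain weights must be non-negative")
    total = sum(domain_mix.values())
    if total <= 0:
        raise ValueError("at least one domain needs a positive weight")
    return {domain: w / total for domain, w in domain_mix.items()}


def sample_domains(domain_mix, k, seed=0):
    """Draw k domain names with probability proportional to their weight."""
    probs = normalize_mix(domain_mix)
    rng = random.Random(seed)  # seeded so the draw is reproducible
    return rng.choices(list(probs), weights=list(probs.values()), k=k)


balanced = normalize_mix({"math": 1.0, "coding": 1.0})
draws = sample_domains({"math": 0.7, "coding": 0.3}, k=100, seed=1)
```

Under this interpretation a preset of all ones and a preset of equal fractions behave identically, so only the ratios between entries matter.
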
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| # @package actor.domain_mix | ||
| | ||
| math: 1.0 | ||
| guessing: 1.0 | ||
| counting: 1.0 | ||
| chartqa: 1.0 | ||
| miniwob: 1.0 | ||
| coding: 1.0 | ||
| fn_calling: 1.0 | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| # @package actor.domain_mix | ||
| | ||
| math: 0.3 | ||
| coding: 0.7 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| # @package actor.domain_mix | ||
| | ||
| math: 0.4 | ||
| coding: 0.3 | ||
| fn_calling: 0.3 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| # @package actor.domain_mix | ||
| | ||
| math: 0.7 | ||
| coding: 0.3 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| # Mapping between domain identifiers and rollout callables. | ||
| math: pipelinerl.domains.math.generate_math_rollout | ||
| guessing: pipelinerl.domains.guessing.generate_guessing_rollout | ||
| counting: pipelinerl.domains.counting.generate_counting_rollout | ||
| miniwob: pipelinerl.domains.miniwob.rollouts.generate_miniwob_rollout | ||
| chartqa: pipelinerl.domains.chartqa.generate_chartqa_rollout | ||
| coding: pipelinerl.domains.coding.generate_coding_rollout | ||
| fn_calling: pipelinerl.domains.fn_calling.generate_fn_calling_rollout |
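
The mapping above stores each rollout function as a dotted import path. A dispatcher has to turn those strings into callables before it can route a problem to its domain. The sketch below shows one generic way to do that with `importlib`; `resolve_callable` and `build_dispatch_table` are hypothetical names, and standard-library functions stand in for the pipelinerl rollouts so the example runs on its own. How the PR's dispatcher resolves these paths is not visible in this diff.

```python
import importlib


def resolve_callable(dotted_path):
    """Import 'package.module.attr' and return the attribute."""
    module_name, _, attr = dotted_path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected a dotted path, got {dotted_path!r}")
    module = importlib.import_module(module_name)
    return getattr(module, attr)


def build_dispatch_table(domain_rollouts):
    """Map each domain identifier to its resolved callable."""
    return {domain: resolve_callable(path) for domain, path in domain_rollouts.items()}


# Stand-in paths (assumption): stdlib functions replace the real rollout callables.
table = build_dispatch_table({"math": "math.sqrt", "coding": "os.path.basename"})
root = table["math"](9.0)
name = table["coding"]("solutions/answer.py")
```

Resolving every path once at startup means a typo in the config fails early instead of in the middle of a rollout.
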
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| defaults: | ||
| - base | ||
| - _self_ | ||
| | ||
| actor: | ||
| rollout_policy: pipelinerl.domains.fn_calling.generate_fn_calling_rollout | ||
| system_prompt: "" | ||
| task_template: "{task}" | ||
| task_prompt: "" | ||
| ensure_boxed_answers: false | ||
| | ||
| dataset_loader: pipelinerl.domains.fn_calling.dataset.load_problems | ||
| dataset_loader_params: | ||
| dataset_id: ServiceNow-AI/mixed-training-text-datasets | ||
| dataset_config: 80k-if-math-coding-fncalling-stem | ||
| split_ratios: | ||
| train: 0.9 | ||
| validation: 0.05 | ||
| test: 0.05 | ||
| allowed_call_types: [] | ||
| max_examples_per_split: 2048 | ||
| trust_remote_code: true | ||
| huggingface_token: ${oc.env:CODING_HF_TOKEN, null} | ||
| | ||
| train_dataset_names: | ||
| - fn_calling@train | ||
| | ||
| test_dataset_names: | ||
| - fn_calling@validation | ||
| | ||
| environments: | ||
| - key: fn_calling | ||
| mode: remote | ||
| _target_: pipelinerl.domains.fn_calling.AgenticToolsEnvironment | ||
| | ||
| environment_key: fn_calling |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| defaults: | ||
| - base | ||
| - /domain_rollouts@domain_rollouts: base | ||
| - domain_mix: math_coding_70_30 | ||
| - _self_ | ||
| | ||
| actor: | ||
| rollout_policy: pipelinerl.domains.dispatcher.generate_multidomain_rollout | ||
| system_prompt: "" | ||
| task_template: |- | ||
| {task} | ||
| task_prompt: "" | ||
| ensure_boxed_answers: false | ||
| domain_rollouts: | ||
| math: ${domain_rollouts.math} | ||
| coding: ${domain_rollouts.coding} | ||
| coding_time_limit_s: 15.0 | ||
| coding_per_test_timeout_s: 10.0 | ||
| coding_memory_limit_bytes: 1073741824 | ||
| coding_compile_timeout_s: 10.0 | ||
| coding_sandbox_url: ${oc.env:CODING_SANDBOX_URL, "http://sandbox:8080/run_code"} | ||
| | ||
| dataset_loader: pipelinerl.domains.multidomain.loader.load_datasets | ||

Collaborator: Can we have multiple dataloaders at once, so that we load different datasets for different domains in the same experiment?

Collaborator (Author): Yes, it's a proportional concatenation.

Collaborator: So how can we define multiple `dataset_loader` functions in a single config?

Collaborator (Author): In your config you could do the following (if you wanted math, coding, and agentic_fn_calling):

    defaults:
      - base
      - multi_domain: base  # inits multidomain loader and dispatcher
      - domain_mix: main_mix  # Or inline the mix as shown below
    actor:
      domain_mix:
        math: 0.4
        coding: 0.3
        fn_calling: 0.3
    train_dataset_names:
      - math::open_reasoner_zero_57k
      - coding::coding@train
      - fn_calling::fn_calling@train
    test_dataset_names:
      - math::math_500
      - coding::coding@validation
      - fn_calling::fn_calling@validation

| dataset_loader_params: | ||
| per_domain_params: | ||
| coding: | ||
| dataset_id: ServiceNow-AI/mixed-training-text-datasets | ||
| dataset_config: 80k-if-math-coding-fncalling-stem | ||
| split_ratios: | ||
| train: 0.9 | ||
| validation: 0.05 | ||
| test: 0.05 | ||
| allowed_call_types: | ||
| - assert | ||
| - std | ||
| max_examples_per_split: 2048 | ||
| trust_remote_code: true | ||
| huggingface_token: ${oc.env:CODING_HF_TOKEN, null} | ||
| | ||
| environments: | ||
| - key: math | ||
| mode: remote | ||
| replicas_per_actor: ${world.env_replicas_per_actor} | ||
| _target_: pipelinerl.domains.math.MathEnvironment | ||
| - key: coding | ||
| mode: remote | ||
| replicas_per_actor: ${world.env_replicas_per_actor} | ||
| _target_: pipelinerl.domains.coding.CodingSandboxEnvironment | ||
| sandbox_url: ${actor.coding_sandbox_url} | ||
| compile_timeout_s: ${actor.coding_compile_timeout_s} | ||
| run_timeout_s: ${actor.coding_per_test_timeout_s} | ||
| request_timeout_s: ${actor.coding_time_limit_s} | ||
| memory_limit_bytes: ${actor.coding_memory_limit_bytes} | ||
| | ||
| environment_key: null | ||
| | ||
| world: | ||
| env_replicas_per_actor: 1 | ||
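
In the review thread on `dataset_loader`, the author describes the multi-domain loader as "a proportional concatenation" and names datasets as `domain::dataset`. The loader's code is not part of this diff, so the sketch below is only one plausible reading: split each name on `::`, then take from every domain as many examples as its mix weight allows without exceeding what the smallest relative pool can support. All function names here are hypothetical.

```python
def parse_dataset_spec(spec):
    """Split 'domain::dataset' into (domain, dataset_name)."""
    domain, sep, name = spec.partition("::")
    if not sep or not domain or not name:
        raise ValueError(f"expected 'domain::dataset', got {spec!r}")
    return domain, name


def proportional_counts(available, domain_mix):
    """Largest per-domain counts that respect the mix and the available sizes."""
    total = sum(domain_mix.values())
    if total <= 0:
        raise ValueError("domain_mix needs a positive total weight")
    probs = {d: w / total for d, w in domain_mix.items()}
    # The mixed dataset can be no larger than the tightest domain permits.
    budget = min(available[d] / p for d, p in probs.items() if p > 0)
    # The small epsilon guards against float error just below an integer.
    return {d: min(available[d], int(budget * p + 1e-9)) for d, p in probs.items()}


def proportional_concat(per_domain, domain_mix):
    """Concatenate (domain, example) pairs in the proportions of domain_mix."""
    sizes = {d: len(examples) for d, examples in per_domain.items()}
    counts = proportional_counts(sizes, domain_mix)
    mixed = []
    for domain, n in counts.items():
        mixed.extend((domain, example) for example in per_domain[domain][:n])
    return mixed


spec = parse_dataset_spec("coding::coding@train")
counts = proportional_counts({"math": 1000, "coding": 300}, {"math": 0.7, "coding": 0.3})
mixed = proportional_concat(
    {"math": list(range(1000)), "coding": list(range(300))},
    {"math": 0.7, "coding": 0.3},
)
```

This version subsamples the larger pool rather than repeating examples from the smaller one; an implementation that oversamples would be equally consistent with the author's one-line description.
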
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| # @package _global_ | ||
| defaults: | ||
| - /domain_rollouts@domain_rollouts: base | ||
| - domain_mix: null | ||
| | ||
| actor: | ||
| rollout_policy: pipelinerl.domains.dispatcher.generate_multidomain_rollout | ||
| system_prompt: "" | ||
| task_template: |- | ||
| {task} | ||
| task_prompt: "" | ||
| ensure_boxed_answers: false | ||
| domain_mix: null | ||
| domain_rollouts: | ||
| math: ${domain_rollouts.math} | ||
| guessing: ${domain_rollouts.guessing} | ||
| counting: ${domain_rollouts.counting} | ||
| chartqa: ${domain_rollouts.chartqa} | ||
| miniwob: ${domain_rollouts.miniwob} | ||
| coding: ${domain_rollouts.coding} | ||
| fn_calling: ${domain_rollouts.fn_calling} | ||
| coding_time_limit_s: 15.0 | ||
| coding_per_test_timeout_s: 10.0 | ||
| coding_memory_limit_bytes: 1073741824 | ||
| coding_compile_timeout_s: 10.0 | ||
| coding_sandbox_url: ${oc.env:CODING_SANDBOX_URL, "http://sandbox:8080/run_code"} | ||
| | ||
| dataset_loader: pipelinerl.domains.multidomain.loader.load_datasets | ||
| dataset_loader_params: | ||
| per_domain_params: | ||
| coding: | ||
| dataset_id: ServiceNow-AI/mixed-training-text-datasets | ||
| dataset_config: 80k-if-math-coding-fncalling-stem | ||
| split_ratios: | ||
| train: 0.9 | ||
| validation: 0.05 | ||
| test: 0.05 | ||
| allowed_call_types: | ||
| - assert | ||
| - std | ||
| max_examples_per_split: 2048 | ||
| trust_remote_code: true | ||
| huggingface_token: ${oc.env:CODING_HF_TOKEN, null} | ||
| fn_calling: | ||
| dataset_id: ServiceNow-AI/mixed-training-text-datasets | ||
| dataset_config: 80k-if-math-coding-fncalling-stem | ||
| split_ratios: | ||
| train: 0.9 | ||
| validation: 0.05 | ||
| test: 0.05 | ||
| allowed_call_types: [] | ||
| max_examples_per_split: 2048 | ||
| trust_remote_code: true | ||
| huggingface_token: ${oc.env:CODING_HF_TOKEN, null} | ||
| | ||
| environments: | ||
| - key: math | ||
| mode: remote | ||
| replicas_per_actor: ${world.env_replicas_per_actor} | ||
| _target_: pipelinerl.domains.math.MathEnvironment | ||
| - key: coding | ||
| mode: remote | ||
| replicas_per_actor: ${world.env_replicas_per_actor} | ||
| _target_: pipelinerl.domains.coding.CodingSandboxEnvironment | ||
| sandbox_url: ${actor.coding_sandbox_url} | ||
| compile_timeout_s: ${actor.coding_compile_timeout_s} | ||
| run_timeout_s: ${actor.coding_per_test_timeout_s} | ||
| request_timeout_s: ${actor.coding_time_limit_s} | ||
| memory_limit_bytes: ${actor.coding_memory_limit_bytes} | ||
| - key: fn_calling | ||
| mode: remote | ||
| replicas_per_actor: ${world.env_replicas_per_actor} | ||
| _target_: pipelinerl.domains.fn_calling.AgenticToolsEnvironment | ||
| max_workers: 4 | ||
| | ||
| environment_key: null | ||
| | ||
| world: | ||
| env_replicas_per_actor: 1 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| defaults: | ||
| - base | ||
| - domain_mix: main_mix | ||
| - _self_ | ||
| | ||
| actor: | ||
| domain_rollouts: | ||
| math: ${domain_rollouts.math} | ||
| coding: ${domain_rollouts.coding} | ||
| fn_calling: ${domain_rollouts.fn_calling} |