diff --git a/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_llama_model.py b/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_llama_model.py index a12f511e88afe..17516fbd57139 100644 --- a/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_llama_model.py +++ b/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_llama_model.py @@ -44,14 +44,35 @@ def test_simple_net_hybrid_strategy(self): ) -class TestSemiAutoParallelLlama2D(test_base.CommunicationTestDistBase): +class TestSemiAutoParallelLlama2DBase(test_base.CommunicationTestDistBase): def setUp(self): super().setUp(num_of_devices=4, timeout=400, nnode=1) self._default_envs = {"dp": "2", "mp": "2", "pp": "1", "acc_step": "2"} self._changeable_envs = { "backend": ["gpu"], - "use_sp": ["true", "false"], - "recompute": ["true", "false"], + "use_sp": ["false"], + "recompute": ["false"], + } + + def test_simple_net_hybrid_strategy(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_llama.py", + user_defined_envs=envs, + ) + + +class TestSemiAutoParallelLlama2DTest(test_base.CommunicationTestDistBase): + def setUp(self): + super().setUp(num_of_devices=4, timeout=400, nnode=1) + self._default_envs = {"dp": "2", "mp": "2", "pp": "1", "acc_step": "2"} + self._changeable_envs = { + "backend": ["gpu"], + "use_sp": ["true"], + "recompute": ["true"], "recompute_granularity": ["full", "full_attn", "core_attn"], } @@ -66,15 +87,37 @@ def test_simple_net_hybrid_strategy(self): ) -class TestSemiAutoParallelLlama3D(test_base.CommunicationTestDistBase): +class TestSemiAutoParallelLlama3DBase(test_base.CommunicationTestDistBase): def setUp(self): super().setUp(num_of_devices=8, timeout=200, nnode=1) self._default_envs = {"dp": "2", "mp": "2", "pp": "2", "acc_step": "2"} self._changeable_envs = { "backend": ["gpu"], - "use_sp": ["true", "false"], - "use_param_group": ["false", "true"], - "recompute": ["true", "false"], + "use_sp": ["false"], + "use_param_group": ["false"], + "recompute": ["false"], + } + + def test_simple_net_hybrid_strategy(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_llama.py", + user_defined_envs=envs, + ) + + +class TestSemiAutoParallelLlama3DTest(test_base.CommunicationTestDistBase): + def setUp(self): + super().setUp(num_of_devices=8, timeout=200, nnode=1) + self._default_envs = {"dp": "2", "mp": "2", "pp": "2", "acc_step": "2"} + self._changeable_envs = { + "backend": ["gpu"], + "use_sp": ["true"], + "use_param_group": ["true"], + "recompute": ["true"], "recompute_granularity": ["full", "full_attn", "core_attn"], } @@ -89,7 +132,34 @@ def test_simple_net_hybrid_strategy(self): ) -class TestSemiAutoParallelLlamaACC(test_base.CommunicationTestDistBase): +class TestSemiAutoParallelLlamaACCBase(test_base.CommunicationTestDistBase): + def setUp(self): + super().setUp(num_of_devices=8, timeout=200, nnode=1) + self._default_envs = { + "dp": "2", + "mp": "2", + "pp": "2", + "acc_step": "1", + "FLAGS_embedding_deterministic": "1", + "FLAGS_cudnn_deterministic": "1", + } + self._changeable_envs = { + "backend": ["gpu"], + "recompute": ["false"], + } + + def test_simple_net_hybrid_strategy_acc(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_llama.py", + user_defined_envs=envs, + ) + + +class TestSemiAutoParallelLlamaACCTest(test_base.CommunicationTestDistBase): def setUp(self): super().setUp(num_of_devices=8, timeout=200, nnode=1) self._default_envs = { @@ -102,7 +172,7 @@ def setUp(self): } self._changeable_envs = { "backend": ["gpu"], - "recompute": ["true", "false"], + "recompute": ["true"], "recompute_granularity": ["full", "full_attn", "core_attn"], }