@@ -581,4 +581,324 @@ public void testUpdateModelDeployStatusAndTriggerOnNodesAction_whenMLTaskManager
581581 verify (mlTaskManager ).updateMLTask (anyString (), any (), anyMap (), anyLong (), anyBoolean ());
582582 }
583583
584+ public void testDeployRemoteModel_success () {
585+ MLModel mlModel = mock (MLModel .class );
586+ when (mlModel .getModelId ()).thenReturn ("test-model-id" );
587+ when (mlModel .getTenantId ()).thenReturn ("test-tenant" );
588+ when (mlModel .getModelContentHash ()).thenReturn ("test-hash" );
589+ when (mlModel .getIsHidden ()).thenReturn (false );
590+
591+ MLTask mlTask = mock (MLTask .class );
592+ when (mlTask .getTaskId ()).thenReturn ("test-task-id" );
593+
594+ DiscoveryNode node = mock (DiscoveryNode .class );
595+ when (node .getId ()).thenReturn ("node1" );
596+ List <DiscoveryNode > nodes = List .of (node );
597+
598+ doAnswer (invocation -> {
599+ ActionListener <UpdateResponse > listener = invocation .getArgument (3 );
600+ listener .onResponse (mock (UpdateResponse .class ));
601+ return null ;
602+ }).when (mlModelManager ).updateModel (anyString (), anyString (), anyMap (), any ());
603+
604+ doAnswer (invocation -> {
605+ ActionListener <MLDeployModelNodesResponse > listener = invocation .getArgument (2 );
606+ listener .onResponse (mock (MLDeployModelNodesResponse .class ));
607+ return null ;
608+ }).when (client ).execute (any (), any (), any ());
609+
610+ when (mlTaskManager .contains (anyString ())).thenReturn (true );
611+
612+ ActionListener <MLDeployModelResponse > listener = mock (ActionListener .class );
613+ transportDeployModelAction .deployRemoteModel (mlModel , mlTask , "local-node" , nodes , true , listener );
614+
615+ verify (listener ).onResponse (any (MLDeployModelResponse .class ));
616+ }
617+
618+ public void testDeployRemoteModel_failure () {
619+ MLModel mlModel = mock (MLModel .class );
620+ when (mlModel .getModelId ()).thenReturn ("test-model-id" );
621+ when (mlModel .getTenantId ()).thenReturn ("test-tenant" );
622+ when (mlModel .getModelContentHash ()).thenReturn ("test-hash" );
623+ when (mlModel .getIsHidden ()).thenReturn (false );
624+
625+ MLTask mlTask = mock (MLTask .class );
626+ when (mlTask .getTaskId ()).thenReturn ("test-task-id" );
627+
628+ DiscoveryNode node = mock (DiscoveryNode .class );
629+ when (node .getId ()).thenReturn ("node1" );
630+ List <DiscoveryNode > nodes = List .of (node );
631+
632+ doAnswer (invocation -> {
633+ ActionListener <UpdateResponse > listener = invocation .getArgument (3 );
634+ listener .onFailure (new RuntimeException ("Update failed" ));
635+ return null ;
636+ }).when (mlModelManager ).updateModel (anyString (), anyString (), anyMap (), any ());
637+
638+ ActionListener <MLDeployModelResponse > listener = mock (ActionListener .class );
639+ transportDeployModelAction .deployRemoteModel (mlModel , mlTask , "local-node" , nodes , true , listener );
640+
641+ verify (listener ).onFailure (any (RuntimeException .class ));
642+ }
643+
644+ public void testDoExecute_deployToAllNodes_false () {
645+ MLModel mlModel = mock (MLModel .class );
646+ when (mlModel .getAlgorithm ()).thenReturn (FunctionName .ANOMALY_LOCALIZATION );
647+ when (mlModel .getModelGroupId ()).thenReturn ("test-group-id" );
648+ when (mlModel .getIsHidden ()).thenReturn (false );
649+
650+ // Use the existing mlDeployModelRequest but override specific nodes
651+ when (mlDeployModelRequest .getModelNodeIds ()).thenReturn (new String [] { "node1" , "node2" });
652+
653+ doAnswer (invocation -> {
654+ ActionListener <MLModel > listener = invocation .getArgument (4 );
655+ listener .onResponse (mlModel );
656+ return null ;
657+ }).when (mlModelManager ).getModel (anyString (), any (), isNull (), any (String [].class ), any ());
658+
659+ // Set up multiple eligible nodes
660+ DiscoveryNode node1 = mock (DiscoveryNode .class );
661+ DiscoveryNode node2 = mock (DiscoveryNode .class );
662+ when (node1 .getId ()).thenReturn ("node1" );
663+ when (node2 .getId ()).thenReturn ("node2" );
664+ DiscoveryNode [] nodes = { node1 , node2 };
665+ when (nodeFilter .getEligibleNodes (any ())).thenReturn (nodes );
666+ when (mlModelManager .getWorkerNodes (anyString (), any ())).thenReturn (null );
667+
668+ IndexResponse indexResponse = mock (IndexResponse .class );
669+ when (indexResponse .getId ()).thenReturn ("task-id" );
670+ doAnswer (invocation -> {
671+ ActionListener <IndexResponse > listener = invocation .getArgument (1 );
672+ listener .onResponse (indexResponse );
673+ return null ;
674+ }).when (mlTaskManager ).createMLTask (any (MLTask .class ), any ());
675+
676+ ActionListener <MLDeployModelResponse > listener = mock (ActionListener .class );
677+ transportDeployModelAction .doExecute (null , mlDeployModelRequest , listener );
678+
679+ verify (listener ).onResponse (any (MLDeployModelResponse .class ));
680+ }
681+
682+ public void testDoExecute_workerNodesConflict () {
683+ MLModel mlModel = mock (MLModel .class );
684+ when (mlModel .getAlgorithm ()).thenReturn (FunctionName .ANOMALY_LOCALIZATION );
685+ when (mlModel .getModelGroupId ()).thenReturn ("test-group-id" );
686+ when (mlModel .getIsHidden ()).thenReturn (false );
687+
688+ // Use the existing mlDeployModelRequest but override specific nodes
689+ when (mlDeployModelRequest .getModelNodeIds ()).thenReturn (new String [] { "node1" });
690+
691+ doAnswer (invocation -> {
692+ ActionListener <MLModel > listener = invocation .getArgument (4 );
693+ listener .onResponse (mlModel );
694+ return null ;
695+ }).when (mlModelManager ).getModel (anyString (), any (), isNull (), any (String [].class ), any ());
696+
697+ // Set up eligible nodes
698+ DiscoveryNode node1 = mock (DiscoveryNode .class );
699+ DiscoveryNode node2 = mock (DiscoveryNode .class );
700+ when (node1 .getId ()).thenReturn ("node1" );
701+ when (node2 .getId ()).thenReturn ("node2" );
702+ DiscoveryNode [] nodes = { node1 , node2 };
703+ when (nodeFilter .getEligibleNodes (any ())).thenReturn (nodes );
704+
705+ // Set up worker nodes conflict - model is already deployed on node2 but target is node1
706+ when (mlModelManager .getWorkerNodes (anyString (), any ())).thenReturn (new String [] { "node2" });
707+
708+ ActionListener <MLDeployModelResponse > listener = mock (ActionListener .class );
709+ transportDeployModelAction .doExecute (null , mlDeployModelRequest , listener );
710+
711+ verify (listener ).onFailure (any (IllegalArgumentException .class ));
712+ }
713+
714+ public void testDoExecute_noEligibleNodes () {
715+ MLModel mlModel = mock (MLModel .class );
716+ when (mlModel .getAlgorithm ()).thenReturn (FunctionName .ANOMALY_LOCALIZATION );
717+ when (mlModel .getModelGroupId ()).thenReturn ("test-group-id" );
718+ when (mlModel .getIsHidden ()).thenReturn (false );
719+
720+ // Use the existing mlDeployModelRequest but override to request non-existent node
721+ when (mlDeployModelRequest .getModelNodeIds ()).thenReturn (new String [] { "non-existent-node" });
722+
723+ doAnswer (invocation -> {
724+ ActionListener <MLModel > listener = invocation .getArgument (4 );
725+ listener .onResponse (mlModel );
726+ return null ;
727+ }).when (mlModelManager ).getModel (anyString (), any (), isNull (), any (String [].class ), any ());
728+
729+ // Set up eligible nodes that don't match the requested nodes
730+ DiscoveryNode existingNode = mock (DiscoveryNode .class );
731+ when (existingNode .getId ()).thenReturn ("existing-node" );
732+ DiscoveryNode [] nodes = { existingNode };
733+ when (nodeFilter .getEligibleNodes (any ())).thenReturn (nodes );
734+
735+ ActionListener <MLDeployModelResponse > listener = mock (ActionListener .class );
736+ transportDeployModelAction .doExecute (null , mlDeployModelRequest , listener );
737+
738+ verify (listener ).onFailure (any (IllegalArgumentException .class ));
739+ }
740+
741+ public void testDoExecute_deployToAllNodes_true () {
742+ MLModel mlModel = mock (MLModel .class );
743+ when (mlModel .getAlgorithm ()).thenReturn (FunctionName .ANOMALY_LOCALIZATION );
744+ when (mlModel .getModelGroupId ()).thenReturn ("test-group-id" );
745+ when (mlModel .getIsHidden ()).thenReturn (false );
746+
747+ // Use null or empty array to trigger deployToAllNodes = true
748+ when (mlDeployModelRequest .getModelNodeIds ()).thenReturn (null );
749+
750+ doAnswer (invocation -> {
751+ ActionListener <MLModel > listener = invocation .getArgument (4 );
752+ listener .onResponse (mlModel );
753+ return null ;
754+ }).when (mlModelManager ).getModel (anyString (), any (), isNull (), any (String [].class ), any ());
755+
756+ // Set up eligible nodes
757+ DiscoveryNode node1 = mock (DiscoveryNode .class );
758+ DiscoveryNode node2 = mock (DiscoveryNode .class );
759+ when (node1 .getId ()).thenReturn ("node1" );
760+ when (node2 .getId ()).thenReturn ("node2" );
761+ DiscoveryNode [] nodes = { node1 , node2 };
762+ when (nodeFilter .getEligibleNodes (any ())).thenReturn (nodes );
763+
764+ IndexResponse indexResponse = mock (IndexResponse .class );
765+ when (indexResponse .getId ()).thenReturn ("task-id" );
766+ doAnswer (invocation -> {
767+ ActionListener <IndexResponse > listener = invocation .getArgument (1 );
768+ listener .onResponse (indexResponse );
769+ return null ;
770+ }).when (mlTaskManager ).createMLTask (any (MLTask .class ), any ());
771+
772+ ActionListener <MLDeployModelResponse > listener = mock (ActionListener .class );
773+ transportDeployModelAction .doExecute (null , mlDeployModelRequest , listener );
774+
775+ verify (listener ).onResponse (any (MLDeployModelResponse .class ));
776+ }
777+
778+ public void testDoExecute_accessControlFailure () {
779+ MLModel mlModel = mock (MLModel .class );
780+ when (mlModel .getAlgorithm ()).thenReturn (FunctionName .ANOMALY_LOCALIZATION );
781+ when (mlModel .getModelGroupId ()).thenReturn ("test-group-id" );
782+ when (mlModel .getIsHidden ()).thenReturn (false );
783+
784+ doAnswer (invocation -> {
785+ ActionListener <MLModel > listener = invocation .getArgument (4 );
786+ listener .onResponse (mlModel );
787+ return null ;
788+ }).when (mlModelManager ).getModel (anyString (), any (), isNull (), any (String [].class ), any ());
789+
790+ // Mock access control to return false (no access)
791+ doAnswer (invocation -> {
792+ ActionListener <Boolean > listener = invocation .getArgument (3 );
793+ listener .onResponse (false );
794+ return null ;
795+ }).when (modelAccessControlHelper ).validateModelGroupAccess (any (), anyString (), any (), any ());
796+
797+ ActionListener <MLDeployModelResponse > listener = mock (ActionListener .class );
798+ transportDeployModelAction .doExecute (null , mlDeployModelRequest , listener );
799+
800+ verify (listener ).onFailure (any (OpenSearchStatusException .class ));
801+ }
802+
803+ public void testDoExecute_hiddenModelNonSuperAdmin () {
804+ MLModel mlModel = mock (MLModel .class );
805+ when (mlModel .getAlgorithm ()).thenReturn (FunctionName .ANOMALY_LOCALIZATION );
806+ when (mlModel .getModelGroupId ()).thenReturn ("test-group-id" );
807+ when (mlModel .getIsHidden ()).thenReturn (true );
808+
809+ doAnswer (invocation -> {
810+ ActionListener <MLModel > listener = invocation .getArgument (4 );
811+ listener .onResponse (mlModel );
812+ return null ;
813+ }).when (mlModelManager ).getModel (anyString (), any (), isNull (), any (String [].class ), any ());
814+
815+ // Mock the isSuperAdminUserWrapper to return false (not super admin)
816+ TransportDeployModelAction spyAction = spy (transportDeployModelAction );
817+ doReturn (false ).when (spyAction ).isSuperAdminUserWrapper (any (), any ());
818+
819+ ActionListener <MLDeployModelResponse > listener = mock (ActionListener .class );
820+ spyAction .doExecute (null , mlDeployModelRequest , listener );
821+
822+ verify (listener ).onFailure (any (OpenSearchStatusException .class ));
823+ }
824+
825+ public void testDoExecute_taskManagerNotContains () {
826+ MLModel mlModel = mock (MLModel .class );
827+ when (mlModel .getAlgorithm ()).thenReturn (FunctionName .ANOMALY_LOCALIZATION );
828+ when (mlModel .getModelGroupId ()).thenReturn ("test-group-id" );
829+ when (mlModel .getIsHidden ()).thenReturn (false );
830+
831+ doAnswer (invocation -> {
832+ ActionListener <MLModel > listener = invocation .getArgument (4 );
833+ listener .onResponse (mlModel );
834+ return null ;
835+ }).when (mlModelManager ).getModel (anyString (), any (), isNull (), any (String [].class ), any ());
836+
837+ IndexResponse indexResponse = mock (IndexResponse .class );
838+ when (indexResponse .getId ()).thenReturn ("task-id" );
839+ doAnswer (invocation -> {
840+ ActionListener <IndexResponse > listener = invocation .getArgument (1 );
841+ listener .onResponse (indexResponse );
842+ return null ;
843+ }).when (mlTaskManager ).createMLTask (any (MLTask .class ), any ());
844+
845+ // Mock mlTaskManager.contains to return false
846+ when (mlTaskManager .contains (anyString ())).thenReturn (false );
847+
848+ ActionListener <MLDeployModelResponse > listener = mock (ActionListener .class );
849+ transportDeployModelAction .doExecute (null , mlDeployModelRequest , listener );
850+
851+ verify (listener ).onResponse (any (MLDeployModelResponse .class ));
852+ }
853+
854+ public void testDoExecute_customDeploymentNotAllowed () {
855+ // Override the settings to disable custom deployment plan
856+ Settings restrictiveSettings = Settings .builder ().put (ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN .getKey (), false ).build ();
857+ ClusterSettings restrictiveClusterSettings = new ClusterSettings (
858+ restrictiveSettings ,
859+ new HashSet <>(Arrays .asList (ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN ))
860+ );
861+ when (clusterService .getClusterSettings ()).thenReturn (restrictiveClusterSettings );
862+ when (clusterService .getSettings ()).thenReturn (restrictiveSettings );
863+
864+ // Create a new instance with restrictive settings
865+ TransportDeployModelAction restrictiveAction = new TransportDeployModelAction (
866+ transportService ,
867+ actionFilters ,
868+ modelHelper ,
869+ mlTaskManager ,
870+ clusterService ,
871+ threadPool ,
872+ client ,
873+ sdkClient ,
874+ namedXContentRegistry ,
875+ nodeFilter ,
876+ mlTaskDispatcher ,
877+ mlModelManager ,
878+ mlStats ,
879+ restrictiveSettings ,
880+ modelAccessControlHelper ,
881+ mlFeatureEnabledSetting
882+ );
883+
884+ MLModel mlModel = mock (MLModel .class );
885+ when (mlModel .getAlgorithm ()).thenReturn (FunctionName .ANOMALY_LOCALIZATION );
886+ when (mlModel .getModelGroupId ()).thenReturn ("test-group-id" );
887+ when (mlModel .getIsHidden ()).thenReturn (false );
888+
889+ // Set specific nodes (not deploy to all)
890+ when (mlDeployModelRequest .getModelNodeIds ()).thenReturn (new String [] { "node1" });
891+
892+ doAnswer (invocation -> {
893+ ActionListener <MLModel > listener = invocation .getArgument (4 );
894+ listener .onResponse (mlModel );
895+ return null ;
896+ }).when (mlModelManager ).getModel (anyString (), any (), isNull (), any (String [].class ), any ());
897+
898+ ActionListener <MLDeployModelResponse > listener = mock (ActionListener .class );
899+ restrictiveAction .doExecute (null , mlDeployModelRequest , listener );
900+
901+ verify (listener ).onFailure (any (IllegalArgumentException .class ));
902+ }
903+
584904}
0 commit comments