diff --git a/frame/collect_on_comm.c b/frame/collect_on_comm.c index 15d2c5ef2e..964008533a 100644 --- a/frame/collect_on_comm.c +++ b/frame/collect_on_comm.c @@ -36,11 +36,11 @@ # endif #endif - + int col_on_comm ( int *, int *, void *, int *, void *, int *, int); int dst_on_comm ( int *, int *, void *, int *, void *, int *, int); -void +void COLLECT_ON_COMM ( int * comm, int * typesize , void * inbuf, int *ninbuf , void * outbuf, int * noutbuf ) { @@ -67,8 +67,9 @@ col_on_comm ( int * Fcomm, int * typesize , int *displace ; int noutbuf_loc ; int root_task ; + MPI_Datatype dtype; + int ierr = -1; MPI_Comm *comm, dummy_comm ; - int ierr ; comm = &dummy_comm ; *comm = MPI_Comm_f2c( *Fcomm ) ; @@ -90,28 +91,45 @@ col_on_comm ( int * Fcomm, int * typesize , for ( p = 1 , displace[0] = 0 , noutbuf_loc = recvcounts[0] ; p < ntasks ; p++ ) { displace[p] = displace[p-1]+recvcounts[p-1] ; noutbuf_loc = noutbuf_loc + recvcounts[p] ; + + /* check for overflow: displace is the partial sum of recvcounts, which can overflow for large problems. */ + if (displace[p] < 0) { +#ifndef MS_SUA + fprintf(stderr,"%s %d buffer offset overflow!!\n",__FILE__,__LINE__) ; + fprintf(stderr," ---> p = %d,\n ---> displace[%d] = %d,\n ---> typesize = %d\n", + p, p, displace[p], *typesize); +#endif + MPI_Abort(MPI_COMM_WORLD,1) ; + } } if ( noutbuf_loc > * noutbuf ) { #ifndef MS_SUA fprintf(stderr,"FATAL ERROR: collect_on_comm: noutbuf_loc (%d) > noutbuf (%d)\n", - noutbuf_loc , * noutbuf ) ; + noutbuf_loc , * noutbuf ) ; fprintf(stderr,"WILL NOT perform the collection operation\n") ; #endif MPI_Abort(MPI_COMM_WORLD,1) ; } - /* multiply everything by the size of the type */ - for ( p = 0 ; p < ntasks ; p++ ) { - displace[p] *= *typesize ; - recvcounts[p] *= *typesize ; + } + + /* handle different sized data types appropriately. */ + ierr = MPI_Type_match_size (MPI_TYPECLASS_REAL, *typesize, &dtype); + if (MPI_SUCCESS != ierr) { + ierr = MPI_Type_match_size (MPI_TYPECLASS_INTEGER, *typesize, &dtype); + if (MPI_SUCCESS != ierr) { +#ifndef MS_SUA + fprintf(stderr,"%s %d FATAL ERROR: unhandled typesize = %d!!\n", __FILE__,__LINE__,*typesize) ; +#endif + MPI_Abort(MPI_COMM_WORLD,1) ; } } - ierr = MPI_Gatherv( inbuf , *ninbuf * *typesize , MPI_CHAR , - outbuf , recvcounts , displace, MPI_CHAR , - root_task , *comm ) ; + ierr = MPI_Gatherv( inbuf , *ninbuf, dtype, + outbuf , recvcounts , displace, dtype, + root_task , *comm ) ; #ifndef MS_SUA if ( ierr != 0 ) fprintf(stderr,"%s %d MPI_Gatherv returns %d\n",__FILE__,__LINE__,ierr ) ; #endif @@ -152,6 +170,8 @@ dst_on_comm ( int * Fcomm, int * typesize , int *displace ; int noutbuf_loc ; int root_task ; + MPI_Datatype dtype; + int ierr = -1; MPI_Comm *comm, dummy_comm ; comm = &dummy_comm ; @@ -171,18 +191,34 @@ dst_on_comm ( int * Fcomm, int * typesize , for ( p = 1 , displace[0] = 0 , noutbuf_loc = sendcounts[0] ; p < ntasks ; p++ ) { displace[p] = displace[p-1]+sendcounts[p-1] ; noutbuf_loc = noutbuf_loc + sendcounts[p] ; + + /* check for overflow: displace is the partial sum of sendcounts, which can overflow for large problems. */ + if ( (displace[p] < 0) || (noutbuf_loc < 0) ) { +#ifndef MS_SUA + fprintf(stderr,"%s %d buffer offset overflow!!\n",__FILE__,__LINE__) ; + fprintf(stderr," ---> p = %d,\n ---> displace[%d] = %d,\n ---> noutbuf_loc = %d,\n ---> typesize = %d\n", + p, p, displace[p], noutbuf_loc, *typesize); +#endif + MPI_Abort(MPI_COMM_WORLD,1) ; + } } + } - /* multiply everything by the size of the type */ - for ( p = 0 ; p < ntasks ; p++ ) { - displace[p] *= *typesize ; - sendcounts[p] *= *typesize ; + /* handle different sized data types appropriately. */ + ierr = MPI_Type_match_size (MPI_TYPECLASS_REAL, *typesize, &dtype); + if (MPI_SUCCESS != ierr) { + ierr = MPI_Type_match_size (MPI_TYPECLASS_INTEGER, *typesize, &dtype); + if (MPI_SUCCESS != ierr) { +#ifndef MS_SUA + fprintf(stderr,"%s %d FATAL ERROR: unhandled typesize = %d!!\n", __FILE__,__LINE__,*typesize) ; +#endif + MPI_Abort(MPI_COMM_WORLD,1) ; } } - MPI_Scatterv( inbuf , sendcounts , displace, MPI_CHAR , - outbuf , *noutbuf * *typesize , MPI_CHAR , - root_task , *comm ) ; + MPI_Scatterv( inbuf, sendcounts, displace, dtype, + outbuf, *noutbuf, dtype, + root_task, *comm ) ; free(sendcounts) ; free(displace) ; @@ -241,4 +277,3 @@ rlim_ () } #endif #endif -