diff --git a/qa/L0_shared_memory/shared_memory_test.py b/qa/L0_shared_memory/shared_memory_test.py index 63f3700c05..f40531b2e6 100755 --- a/qa/L0_shared_memory/shared_memory_test.py +++ b/qa/L0_shared_memory/shared_memory_test.py @@ -47,6 +47,7 @@ class SystemSharedMemoryTestBase(tu.TestResultCollector): DEFAULT_SHM_BYTE_SIZE = 64 + SYS_PAGE_SIZE = os.sysconf("SC_PAGE_SIZE") def setUp(self): self._setup_client() @@ -305,11 +306,10 @@ def test_large_shm_register_offset(self): # Test for large offset error_msg = [] - page_size = os.sysconf("SC_PAGE_SIZE") # Create a large shm size (page_size * 1024 is large enough to reproduce a segfault). # Register offset at 1 page before the end of the shm region to give enough space for the input/output data. - create_byte_size = page_size * 1024 - register_offset = page_size * 1023 + create_byte_size = self.SYS_PAGE_SIZE * 1024 + register_offset = self.SYS_PAGE_SIZE * 1023 self._configure_server( create_byte_size=create_byte_size, register_offset=register_offset, @@ -372,7 +372,12 @@ def test_unregisterall(self): def test_infer_offset_out_of_bound(self): # Shared memory offset outside output region - Throws error error_msg = [] - self._configure_server() + create_byte_size = self.SYS_PAGE_SIZE + self.DEFAULT_SHM_BYTE_SIZE + register_offset = self.SYS_PAGE_SIZE + self._configure_server( + create_byte_size=create_byte_size, + register_offset=register_offset, + ) if self.protocol == "http": # -32 when placed in an int64 signed type, to get a negative offset # by overflowing @@ -402,8 +407,13 @@ def test_infer_offset_out_of_bound(self): def test_infer_byte_size_out_of_bound(self): # Shared memory byte_size outside output region - Throws error error_msg = [] - self._configure_server() - offset = 60 + create_byte_size = self.SYS_PAGE_SIZE + self.DEFAULT_SHM_BYTE_SIZE + register_offset = self.SYS_PAGE_SIZE + self._configure_server( + create_byte_size=create_byte_size, + register_offset=register_offset, + ) + offset = 1 byte_size = self.DEFAULT_SHM_BYTE_SIZE iu.shm_basic_infer( diff --git a/src/shared_memory_manager.cc b/src/shared_memory_manager.cc index a2d52f5f48..bde10b525c 100644 --- a/src/shared_memory_manager.cc +++ b/src/shared_memory_manager.cc @@ -483,14 +483,11 @@ SharedMemoryManager::GetMemoryInfo( } // validate offset - size_t shm_region_end = 0; - if (it->second->kind_ == TRITONSERVER_MEMORY_CPU) { - shm_region_end = it->second->offset_; - } + size_t shm_region_size = 0; if (it->second->byte_size_ > 0) { - shm_region_end += it->second->byte_size_ - 1; + shm_region_size += it->second->byte_size_; } - if (offset > shm_region_end) { + if (offset >= shm_region_size) { return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, std::string("Invalid offset for shared memory region: '" + name + "'") @@ -510,8 +507,8 @@ SharedMemoryManager::GetMemoryInfo( } // validate byte_size + offset is within memory bounds - size_t total_req_shm = offset + byte_size - 1; - if (total_req_shm > shm_region_end) { + size_t total_req_shm = offset + byte_size; + if (total_req_shm > shm_region_size) { return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, std::string(