Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
158 commits
Select commit Hold shift + click to select a range
9805d5c
upgrade numpy to 2.0
lxning Sep 11, 2025
a247511
update setuptool
lxning Sep 12, 2025
b9ce4bb
fix scikit-learn 1.4.2 error
lxning Sep 12, 2025
a78e416
test scipy 1.13.0
lxning Sep 12, 2025
de01800
fix typo
lxning Sep 12, 2025
d348ed3
test scipy 1.10.0
lxning Sep 12, 2025
9c9f24e
try numpy 2.1.0
lxning Sep 12, 2025
13e3f1d
try numpy 2.1.0
lxning Sep 12, 2025
22f1b15
set numba 0.61.0
lxning Sep 12, 2025
e4acfd1
set pyyaml 5.4.1
lxning Sep 12, 2025
b46448d
set pyyaml 6.0.1
lxning Sep 12, 2025
9f66b73
set cryptography 45.0.5
lxning Sep 12, 2025
3a88403
set requests 2.32.3
lxning Sep 12, 2025
eda94ea
fix image name
lxning Sep 12, 2025
d567cf6
set panda 2.2.0
lxning Sep 12, 2025
f5677b6
set panda 2.2.3
lxning Sep 12, 2025
85b7540
set python-dateutil==2.8.2
lxning Sep 12, 2025
d949999
set pyarrow 17.0.0
lxning Sep 12, 2025
4522abf
replace rabit with dask-based api
lxning Sep 12, 2025
a3d19d2
set protobuf 5.26
lxning Sep 12, 2025
f9de2dd
install pyarrow in container
lxning Sep 12, 2025
c56031f
set tbb 2022.2.0
lxning Sep 12, 2025
edab2d8
try mlio-py with pyarrow 17.0.0
lxning Sep 13, 2025
077edc3
try mlio-py with pyarrow 17.0.0
lxning Sep 13, 2025
5d07aa2
try install mlio
lxning Sep 13, 2025
e13d8f4
try install mlio
lxning Sep 13, 2025
ff431f5
hack mlio
lxning Sep 13, 2025
08ce40e
hack mlio
lxning Sep 13, 2025
20d4e69
try protobuf 3.20.1
lxning Sep 13, 2025
a219214
set dask 2024.10.0
lxning Sep 13, 2025
59a91aa
set dask 2024.9.0
lxning Sep 13, 2025
7324839
set psutil 5.8.0
lxning Sep 13, 2025
e759563
update train test minor version
lxning Sep 15, 2025
28662dc
set matplotlib==3.6.3
lxning Sep 16, 2025
3515083
set matplotlib==3.6.3
lxning Sep 16, 2025
2da93ab
Trigger Build
lxning Sep 23, 2025
01ad9b1
test dask migration
lxning Sep 23, 2025
054c09c
test xgb migration
lxning Sep 23, 2025
cff59cb
test xgb rabit
lxning Sep 23, 2025
9c3fe62
test xgb rapit migration
lxning Sep 24, 2025
103c0c9
test dask expr backend migration
lxning Sep 24, 2025
0e2a074
test rabit.tracker_print migration
lxning Sep 24, 2025
d3e4ee5
test rabit and libsvm migration
lxning Sep 24, 2025
5e6a1b2
test rabit and dask
lxning Sep 24, 2025
c4cd3e6
test dask migratinon
lxning Sep 24, 2025
be02e7c
test rabit
lxning Sep 24, 2025
7e4b844
test _aggregate_predictions
lxning Sep 24, 2025
56c302d
recover checkpointing.py distributed.py
lxning Sep 24, 2025
57ffac5
rabit deprecate
lxning Sep 25, 2025
15e7484
set env var
lxning Sep 25, 2025
2acc1ec
test distributed.py
lxning Sep 25, 2025
e9cea56
test distributed.py
lxning Sep 25, 2025
c003caf
replace rabit with dask
lxning Sep 25, 2025
41d1794
replace rabit with collective
lxning Sep 25, 2025
1d9372e
replace rabit with collective
lxning Sep 25, 2025
af3b5f8
replace rabit with collective
lxning Sep 25, 2025
bcd2bc9
fmt
lxning Sep 25, 2025
6f982e8
fix sklearn api deprecations
lxning Sep 26, 2025
7620c5d
backward compatible for unit test
lxning Sep 26, 2025
5f9ec05
fmt
lxning Sep 26, 2025
82f9b45
set matplotlib
lxning Sep 26, 2025
1580cdc
set matplotlib==3.6.3
lxning Sep 26, 2025
97e704c
set matplotlib==3.9.2
lxning Sep 26, 2025
7f629d8
set matplotlib==3.9.2
lxning Sep 26, 2025
be49c37
fix model name
lxning Sep 26, 2025
a5b6f48
fix distributed training save model
lxning Sep 26, 2025
90c3163
fix distributed training save model
lxning Sep 26, 2025
5a1611b
fix distributed training save model
lxning Sep 26, 2025
842b74d
fix distributed training save model
lxning Sep 26, 2025
3a90591
fix distributed training save model
lxning Sep 26, 2025
147731c
fix distributed training save model
lxning Sep 26, 2025
cbc6057
fix distributed training save model
lxning Sep 27, 2025
d13c5f3
fix distributed training save model
lxning Sep 27, 2025
ff91b36
fix distributed training save model
lxning Sep 27, 2025
859e13a
fix distributed training save model
lxning Sep 27, 2025
0ac1a66
debug
lxning Sep 28, 2025
3d28094
debug master host
lxning Sep 28, 2025
846477c
debug master host
lxning Sep 28, 2025
9d1adea
debug master host
lxning Sep 28, 2025
5bac086
debug master host
lxning Sep 28, 2025
695a28d
debug master host
lxning Sep 28, 2025
609fc1f
debug master host
lxning Sep 28, 2025
50a302e
debug master host
lxning Sep 28, 2025
cf13b97
debug master host
lxning Sep 28, 2025
242d9b2
debug master host
lxning Sep 29, 2025
b0589a2
debug master host
lxning Sep 29, 2025
22f74fc
debug master host
lxning Sep 29, 2025
a5c96a0
debug master host
lxning Sep 29, 2025
8e63e63
debug master host
lxning Sep 29, 2025
fb68231
debug master host
lxning Sep 29, 2025
92467f4
debug master host
lxning Sep 29, 2025
e4fa859
debug master host
lxning Sep 29, 2025
d3b5062
debug master host
lxning Sep 29, 2025
1abfe25
debug master host
lxning Sep 29, 2025
2fdd648
debug master host
lxning Sep 29, 2025
a805064
debug master host
lxning Sep 29, 2025
ac7a788
check xgboost 2.1.1
lxning Sep 30, 2025
41dca75
check xgboost 2.1.1
lxning Sep 30, 2025
b7881bf
check xgboost 2.1.1
lxning Sep 30, 2025
bfd87e7
check xgboost 2.1.1
lxning Sep 30, 2025
a3444d1
check xgboost 2.1.1
lxning Sep 30, 2025
216c341
check xgboost 2.1.1
lxning Sep 30, 2025
51d0ebc
check xgboost 2.1.1
lxning Sep 30, 2025
21fc052
check xgboost 2.1.1
lxning Sep 30, 2025
ebf9e7a
check xgboost 2.1.1
lxning Sep 30, 2025
9aef1c7
check xgboost 2.1.1
lxning Sep 30, 2025
04d5eb7
check xgboost 2.1.1
lxning Oct 1, 2025
2e2439f
check xgboost 2.1.1
lxning Oct 1, 2025
849b911
check xgboost 2.1.1
lxning Oct 1, 2025
d9596d7
check xgboost 2.1.1
lxning Oct 1, 2025
d511676
check xgboost 2.1.1
lxning Oct 1, 2025
253934d
check xgboost 2.1.1
lxning Oct 1, 2025
c7961ac
check xgboost 2.1.0
lxning Oct 1, 2025
9ef7cf9
check xgboost 2.1.0
lxning Oct 1, 2025
7939d45
check xgboost 2.1.0
lxning Oct 1, 2025
8451e22
check xgboost 2.1.0
lxning Oct 1, 2025
db2284d
check xgboost 2.1.0
lxning Oct 1, 2025
bf6a40d
check xgboost 2.1.0
lxning Oct 1, 2025
a6ec50f
check xgboost 2.1.0
lxning Oct 1, 2025
e3fb795
check xgboost 2.1.0
lxning Oct 1, 2025
c9021ed
check xgboost 2.1.0
lxning Oct 1, 2025
b40d83e
check xgboost 2.1.0
lxning Oct 1, 2025
2c746dd
check xgboost 2.1.0
lxning Oct 1, 2025
a3193a3
check xgboost 2.1.0
lxning Oct 1, 2025
076c786
check xgboost 2.1.0
lxning Oct 1, 2025
23d6034
check xgboost 2.1.0
lxning Oct 1, 2025
a94d1c2
check xgboost 2.1.0
lxning Oct 2, 2025
09a3a70
check xgboost 2.1.0
lxning Oct 2, 2025
a42aa02
check xgboost 2.1.0
lxning Oct 2, 2025
3137263
check xgboost 2.1.0
lxning Oct 2, 2025
3581c03
check xgboost 2.1.0
lxning Oct 2, 2025
e687bb3
check xgboost 2.1.0
lxning Oct 2, 2025
cc0f366
check xgboost 2.1.0
lxning Oct 2, 2025
90b67cc
check xgboost 2.1.0
lxning Oct 2, 2025
965405d
check xgboost 2.1.0
lxning Oct 2, 2025
702bc16
check xgboost 2.1.0
lxning Oct 2, 2025
68592ba
Merge branch 'master' into lninga_dev
lxning Oct 3, 2025
d0f7492
test xgboost 3.0.5
lxning Oct 3, 2025
87e8ea4
test xgboost 3.0.5
lxning Oct 3, 2025
6ebd902
test xgboost 3.0.5
lxning Oct 3, 2025
445f2c6
test xgboost 3.0.5
lxning Oct 3, 2025
180d6f6
test xgboost 3.0.5
lxning Oct 3, 2025
7555efa
test xgboost 3.0.5
lxning Oct 3, 2025
27fb002
test xgboost 3.0.5
lxning Oct 3, 2025
71ac152
test xgboost 3.0.5
lxning Oct 4, 2025
fa5b3ce
test xgboost 3.0.5
lxning Oct 4, 2025
c78e283
test xgboost 3.0.5
lxning Oct 4, 2025
674206b
test xgboost 3.0.5
lxning Oct 4, 2025
bea2f33
test xgboost 3.0.5
lxning Oct 4, 2025
6d1ea7c
test xgboost 3.0.5
lxning Oct 4, 2025
2a2d155
test xgboost 3.0.5
lxning Oct 4, 2025
6cde744
test xgboost 3.0.5
lxning Oct 4, 2025
6cc8a49
test xgboost 3.0.5
lxning Oct 5, 2025
a6a10e2
rename 2.1.0 with 3.0.5
lxning Oct 6, 2025
2e34b84
test 3.0.5
lxning Oct 7, 2025
fa34fe3
test 3.0.5
lxning Oct 7, 2025
9ed4eff
cuda 12.0.0
lxning Oct 10, 2025
6f043d9
roll back to cuda 11.6.1
lxning Oct 10, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 199 additions & 0 deletions docker/3.0.5/base/Dockerfile.cpu
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
# Global build arguments. They are declared before the first FROM, so they are
# only available to the FROM lines of this file.
# Ubuntu release used by the SQLite builder stage and in the CUDA base image tag.
ARG UBUNTU_VERSION=20.04
# CUDA release used in the nvidia/cuda base image tag.
ARG CUDA_VERSION=11.6.1
# sha256 digest the nvidia/cuda base image is pinned to.
ARG IMAGE_DIGEST=c2d95c9c6ff77da41cf0f2f9e8c5088f5b4db20c16a7566b808762f05b9032ef

# Build stage for SQLite compilation.
# Installs a minimal toolchain, downloads the sqlite-autoconf-3500200 source
# tarball, builds and installs it under /usr/local, then removes the sources and
# the apt caches. Only selected files from /usr/local are copied into the main image.
# NOTE(review): the tarball is downloaded without checksum verification — consider
# pinning a SHA for it.
FROM ubuntu:${UBUNTU_VERSION} as sqlite-builder
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
wget \
ca-certificates \
&& \
cd /tmp && \
wget https://www.sqlite.org/2025/sqlite-autoconf-3500200.tar.gz && \
tar xzf sqlite-autoconf-3500200.tar.gz && \
cd sqlite-autoconf-3500200 && \
./configure --prefix=/usr/local && \
make && \
make install && \
ldconfig && \
cd / && \
rm -rf /tmp/sqlite-autoconf-3500200 /tmp/sqlite-autoconf-3500200.tar.gz && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*

# Main image: nvidia/cuda "base" flavour, pinned by both tag and sha256 digest.
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${UBUNTU_VERSION}@sha256:${IMAGE_DIGEST}

# Stage-scoped build arguments.
# MINICONDA_VERSION / CONDA_PY_VERSION select the Miniconda installer file name;
# CONDA_CHECKSUM is the sha256 that installer is verified against.
ARG MINICONDA_VERSION=24.7.1
ARG CONDA_CHECKSUM=684cda724bc37e3bbbb342e440fc4cac515c92e91a489eb4359feca35382894b
ARG CONDA_PY_VERSION=310
# Version conda itself is pinned to after installation.
ARG CONDA_PKG_VERSION=24.7.1
# Python, PyArrow, MLIO (git tag v<version>) and XGBoost versions installed below.
ARG PYTHON_VERSION=3.10
ARG PYARROW_VERSION=17.0.0
ARG MLIO_VERSION=0.9.0
ARG XGBOOST_VERSION=3.0.5

# Keep apt non-interactive and use a UTF-8 locale for the build and at runtime.
ENV DEBIAN_FRONTEND=noninteractive
ENV LANG=C.UTF-8
ENV LC_ALL=C.UTF-8

# PYTHONDONTWRITEBYTECODE: Python won't write .pyc or .pyo files when importing source modules.
# PYTHONUNBUFFERED: keep the standard streams unbuffered, which is good for logging.
# PYTHONIOENCODING: encoding used for stdin, stdout and stderr.
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
ENV PYTHONIOENCODING='utf-8'

# Install OS-level packages in a single layer:
#   1. swap the base image's NVIDIA apt key for the cuda-keyring package,
#   2. upgrade the distribution and install runtime packages,
#   3. add the Kitware apt repository and install the MLIO build toolchain,
#   4. add the NodeSource repository and install Node.js 20 / npm.
# Downloaded .deb installers are deleted right after use so they do not end up in the image.
# NOTE(review): the cuda-keyring URL (ubuntu1804) and the Kitware suite (bionic) target
# Ubuntu 18.04 while UBUNTU_VERSION defaults to 20.04 — confirm this is intended.
# NOTE(review): libffi7 is fetched over plain HTTP without checksum verification.
RUN apt-key del 7fa2af80 && \
    apt-get update && apt-get install -y --no-install-recommends wget && \
    wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-keyring_1.0-1_all.deb && \
    dpkg -i cuda-keyring_1.0-1_all.deb && \
    rm -f cuda-keyring_1.0-1_all.deb && \
    apt-get update && \
    apt-get -y upgrade && \
    apt-get -y install --no-install-recommends \
        build-essential \
        curl \
        git \
        jq \
        libatlas-base-dev \
        expat \
        nginx \
        openjdk-8-jdk-headless \
        unzip \
        wget \
        apparmor \
        linux-libc-dev \
        libxml2 \
        libgstreamer1.0-0 \
    && \
    wget http://es.archive.ubuntu.com/ubuntu/pool/main/libf/libffi/libffi7_3.3-4_amd64.deb && \
    dpkg -i libffi7_3.3-4_amd64.deb && \
    rm -f libffi7_3.3-4_amd64.deb && \
    apt-get -y install --no-install-recommends \
        apt-transport-https \
        ca-certificates \
        gnupg \
        software-properties-common \
    && \
    # MLIO build dependencies
    # Official Ubuntu APT repositories do not contain an up-to-date version of CMake required to build MLIO.
    # Kitware contains the latest version of CMake.
    wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | \
        gpg --dearmor - | \
        tee /usr/share/keyrings/kitware-archive-keyring.gpg >/dev/null && \
    echo 'deb [signed-by=/usr/share/keyrings/kitware-archive-keyring.gpg] https://apt.kitware.com/ubuntu/ bionic main' | tee /etc/apt/sources.list.d/kitware.list >/dev/null && \
    apt-get update && \
    # The manually written keyring is removed before the kitware-archive-keyring package is installed below.
    rm /usr/share/keyrings/kitware-archive-keyring.gpg && \
    apt-get install -y --no-install-recommends \
        autoconf \
        automake \
        build-essential \
        cmake \
        cmake-data \
        doxygen \
        kitware-archive-keyring \
        libcurl4-openssl-dev \
        libssl-dev \
        libtool \
        ninja-build \
        python3-dev \
        python3-distutils \
        python3-pip \
        zlib1g-dev \
        libxml2 \
        zstd \
        libsqlite3-0 \
    && \
    python3 -m pip install --upgrade pip && \
    python3 -m pip install --upgrade certifi && \
    apt-get clean && \
    # Node.js setup
    mkdir -p /etc/apt/keyrings && \
    curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | \
        gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg && \
    echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_20.x nodistro main" | \
        tee /etc/apt/sources.list.d/nodesource.list && \
    apt-get update && \
    apt-get install -y nodejs && \
    npm install -g npm@latest && \
    rm -rf /var/lib/apt/lists/*

# Install Miniconda: download the installer, check it against CONDA_CHECKSUM,
# run it in batch mode with /miniconda3 as the prefix, then delete the installer.
RUN cd /tmp && \
    curl --location -o /tmp/Miniconda3.sh \
        https://repo.anaconda.com/miniconda/Miniconda3-py${CONDA_PY_VERSION}_${MINICONDA_VERSION}-0-Linux-x86_64.sh && \
    echo "${CONDA_CHECKSUM} /tmp/Miniconda3.sh" | sha256sum --check - && \
    bash /tmp/Miniconda3.sh -b -f -p /miniconda3 && \
    rm /tmp/Miniconda3.sh

# Put the conda-managed interpreter and tools ahead of the system ones.
ENV PATH=/miniconda3/bin:${PATH}

# Install MLIO with Apache Arrow integration

# We could install mlio-py from conda, but the conda package ships extras (such as the image
# reader) that increase the image size, which in turn increases training time.
# We build from source to minimize the image size.
#
# This single layer (1) pins and configures conda / Python / PyArrow, (2) removes Node.js and
# the extra apt sources added earlier, and (3) builds and installs MLIO v${MLIO_VERSION}.
RUN echo "conda ${CONDA_PKG_VERSION}" >> /miniconda3/conda-meta/pinned && \
# Conda configuration see https://conda.io/projects/conda/en/latest/configuration.html
conda config --system --set auto_update_conda false && \
conda config --system --set show_channel_urls true && \
echo "python ${PYTHON_VERSION}.*" >> /miniconda3/conda-meta/pinned && \
conda install -c conda-forge python=${PYTHON_VERSION} --solver classic && \
pip install requests==2.32.3 && \
conda install conda=${CONDA_PKG_VERSION} --solver classic && \
conda update -y conda && \
conda install -c conda-forge pyarrow=${PYARROW_VERSION} --solver classic && \
# Refresh the JS dependencies of the Node example bundled in the libgrpc conda package.
# NOTE(review): presumably done to satisfy image vulnerability scanners — confirm.
cd /miniconda3/pkgs/libgrpc-*/info/test/examples/node && \
npm install minimist@latest protobufjs@latest && \
# Remove Node.js, npm, and their dependencies
apt-get purge -y nodejs npm && \
apt-get autoremove -y && \
# Final cleanup: drop the NodeSource and Kitware apt sources and the apt caches
rm -rf /etc/apt/sources.list.d/nodesource.list \
/etc/apt/keyrings/nodesource.gpg \
/etc/apt/sources.list.d/kitware.list && \
apt-get clean && \
rm -rf /var/lib/apt/lists/* && \
# Build MLIO from the v${MLIO_VERSION} git tag
cd /tmp && \
git clone --branch v${MLIO_VERSION} https://github.com/awslabs/ml-io.git mlio && \
cd mlio && \
# Rewrite MLIO's Arrow / pyarrow pin from 14.0.1 to 17.0.0.
# NOTE(review): 17.0.0 is hard-coded here rather than derived from PYARROW_VERSION;
# keep the two in sync when bumping PyArrow.
sed -i 's/find_package(Arrow 14.0.1 REQUIRED/find_package(Arrow 17.0.0 REQUIRED/g' CMakeLists.txt && \
sed -i 's/pyarrow==14.0.1/pyarrow==17.0.0/g' src/mlio-py/setup.py && \
# Build the third-party dependencies into build/third-party, then the core library
build-tools/build-dependency build/third-party all && \
mkdir -p build/release && \
cd build/release && \
cmake -GNinja -DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_PREFIX_PATH="$(pwd)/../third-party" ../.. && \
cmake --build . && \
cmake --build . --target install && \
# Reconfigure with the Python extension and Arrow integration enabled, then build both targets
cmake -DMLIO_INCLUDE_PYTHON_EXTENSION=ON -DPYTHON_EXECUTABLE="/miniconda3/bin/python3" \
-DMLIO_INCLUDE_ARROW_INTEGRATION=ON ../.. && \
cmake --build . --target mlio-py && \
cmake --build . --target mlio-arrow && \
# Package mlio-py as a wheel and install it
cd ../../src/mlio-py && \
python3 setup.py bdist_wheel && \
python3 -m pip install typing && \
python3 -m pip install --upgrade pip && \
python3 -m pip install dist/*.whl && \
# Keep the TBB libraries built with the third-party dependencies, then drop the source tree
cp -r /tmp/mlio/build/third-party/lib/libtbb* /usr/local/lib/ && \
ldconfig && \
rm -rf /tmp/mlio

# Copy the SQLite binary, shared libraries and headers compiled in the builder stage
COPY --from=sqlite-builder /usr/local/bin/sqlite3 /usr/local/bin/sqlite3
COPY --from=sqlite-builder /usr/local/lib/libsqlite3.* /usr/local/lib/
COPY --from=sqlite-builder /usr/local/include/sqlite3*.h /usr/local/include/

# Register /usr/local/lib with the dynamic linker and rebuild its cache once,
# after the configuration file has been written.
RUN echo "/usr/local/lib" > /etc/ld.so.conf.d/sqlite3.conf && \
    ldconfig

# Put /usr/local/bin first on PATH so the sqlite3 binary copied above is the one that runs
ENV PATH="/usr/local/bin:${PATH}"

# Print the SQLite version to the build logs; the build fails here if the binary cannot run
RUN echo "sqlite3 " && sqlite3 --version

# Dump the installed OS packages to the build logs
RUN apt list --installed

# Install the pinned XGBoost release together with pinned NumPy, PyArrow and pandas.
# -I (--ignore-installed) reinstalls these packages even if a copy is already present.
# PyArrow follows PYARROW_VERSION so it cannot drift from the conda-installed version.
RUN python3 -m pip install --no-cache-dir -I \
    xgboost==${XGBOOST_VERSION} \
    numpy==2.1.0 \
    pyarrow==${PYARROW_VERSION} \
    pandas==2.2.3
96 changes: 96 additions & 0 deletions docker/3.0.5/final/Dockerfile.cpu
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Global build arguments: usable only in the FROM line below.
# Here SAGEMAKER_XGBOOST_VERSION is the dashed tag of the base image.
ARG SAGEMAKER_XGBOOST_VERSION=3.0-5
ARG PYTHON_VERSION=3.10

FROM xgboost-container-base:${SAGEMAKER_XGBOOST_VERSION}-cpu-py3

# Arguments declared before FROM are not visible inside the build stage, so each one
# used by later instructions has to be declared again here.
# Inside the stage SAGEMAKER_XGBOOST_VERSION defaults to the dotted form used by the
# docker/<version>/resources paths.
ARG SAGEMAKER_XGBOOST_VERSION=3.0.5
# Re-declared without a value so it inherits the global default; without this line
# ${PYTHON_VERSION} expands to an empty string in the instructions below.
ARG PYTHON_VERSION

########################
# Install dependencies #
########################
# Install the Python requirements, then remove the requirements file from the image.
COPY requirements.txt /requirements.txt
RUN python3 -m pip install -r /requirements.txt && rm /requirements.txt

# Fix Python 3.10 compatibility for sagemaker-containers
# (earlier attempt, kept for reference; the sed below is what actually runs)
# RUN python3 -c "import sys; sys.path.insert(0, '/miniconda3/lib/python3.10/site-packages'); \
# import sagemaker_containers._mapping as m; \
# import collections.abc; \
# setattr(collections, 'Mapping', collections.abc.Mapping); \
# exec(open('/miniconda3/lib/python3.10/site-packages/sagemaker_containers/_mapping.py').read().replace('collections.Mapping', 'collections.abc.Mapping'))" || \
# sed -i 's/collections\.Mapping/collections.abc.Mapping/g' /miniconda3/lib/python3.10/site-packages/sagemaker_containers/_mapping.py

# Patch sagemaker_containers in place: replace collections.Mapping with collections.abc.Mapping.
# NOTE(review): the site-packages path hard-codes python3.10; keep it in sync with PYTHON_VERSION.
RUN sed -i 's/collections\.Mapping/collections.abc.Mapping/g' /miniconda3/lib/python3.10/site-packages/sagemaker_containers/_mapping.py

# Install smdebug from source
# NOTE(review): the "[email protected]" text in the URL below looks like an e-mail
# obfuscation artifact in this copy of the file; the expected form is
# git+https://github.com/awslabs/<repo>.git@<ref>. Restore the real repo and ref before building.
RUN python3 -m pip install git+https://github.com/awslabs/[email protected]


###########################
# Copy wheel to container #
###########################
# NOTE(review): the wheel is built as version 2.0 but copied under a file name that says 1.0 —
# confirm the rename is intentional.
COPY dist/sagemaker_xgboost_container-2.0-py2.py3-none-any.whl /sagemaker_xgboost_container-1.0-py2.py3-none-any.whl
# In one layer:
#   - delete the numpy-1.21.2 dist-info directory from site-packages, if present,
#   - force-reinstall PyYAML 6.0.1,
#   - install the container wheel without resolving its dependencies (--no-deps),
#   - uninstall the "typing" package,
#   - remove the wheel file from the image.
RUN rm -rf /miniconda3/lib/python${PYTHON_VERSION}/site-packages/numpy-1.21.2.dist-info && \
python3 -m pip install --force-reinstall PyYAML==6.0.1 && \
python3 -m pip install --no-cache --no-deps /sagemaker_xgboost_container-1.0-py2.py3-none-any.whl && \
python3 -m pip uninstall -y typing && \
rm /sagemaker_xgboost_container-1.0-py2.py3-none-any.whl

##############
# DMLC PATCH #
##############
# TODO: remove after making contributions back to xgboost for tracker.py
# COPY src/sagemaker_xgboost_container/dmlc_patch/tracker.py \
# /miniconda3/lib/python${PYTHON_VERSION}/site-packages/xgboost/dmlc-core/tracker/dmlc_tracker/tracker.py

# # Include DMLC python code in PYTHONPATH to use RabitTracker
# ENV PYTHONPATH=$PYTHONPATH:/miniconda3/lib/python${PYTHON_VERSION}/site-packages/xgboost/dmlc-core/tracker

#######
# MMS #
#######
# Create the dedicated MMS user (with a home directory) and a scratch directory it owns
RUN useradd -m model-server
RUN mkdir -p /home/model-server/tmp && chown -R model-server /home/model-server

# Copy the MMS config template; XGBOOST_MMS_CONFIG names the config file path next to it
COPY docker/${SAGEMAKER_XGBOOST_VERSION}/resources/mms/config.properties.tmp /home/model-server
ENV XGBOOST_MMS_CONFIG=/home/model-server/config.properties

# Copy execution parameters endpoint plugin for MMS
RUN mkdir -p /tmp/plugins
COPY docker/${SAGEMAKER_XGBOOST_VERSION}/resources/mms/endpoints-1.0.jar /tmp/plugins
RUN chmod +x /tmp/plugins/endpoints-1.0.jar

# Create directory for models
RUN mkdir -p /opt/ml/models
RUN chmod +rwx /opt/ml/models

# Copy Dask configs (-p keeps the step from failing if /etc/dask already exists)
RUN mkdir -p /etc/dask
COPY docker/configs/dask_configs.yaml /etc/dask/

# Required label for multi-model loading
LABEL com.amazonaws.sagemaker.capabilities.multi-models=true

#####################
# Required ENV vars #
#####################
# All variables use the ENV key=value form; the space-separated form is legacy syntax.

# Set SageMaker training environment variables.
# SM_INPUT is set in its own instruction so the following lines can reference it.
ENV SM_INPUT=/opt/ml/input
ENV SM_INPUT_TRAINING_CONFIG_FILE=$SM_INPUT/config/hyperparameters.json
ENV SM_INPUT_DATA_CONFIG_FILE=$SM_INPUT/config/inputdataconfig.json
ENV SM_CHECKPOINT_CONFIG_FILE=$SM_INPUT/config/checkpointconfig.json
# See: https://github.com/dmlc/xgboost/issues/7982#issuecomment-1379390906 https://github.com/dmlc/xgboost/pull/8257
ENV NCCL_SOCKET_IFNAME=eth


# Set SageMaker serving environment variables
ENV SM_MODEL_DIR=/opt/ml/model

# Set SageMaker entrypoints
ENV SAGEMAKER_TRAINING_MODULE=sagemaker_xgboost_container.training:main
ENV SAGEMAKER_SERVING_MODULE=sagemaker_xgboost_container.serving:main

EXPOSE 8080
# Temp directory inside the model-server home
ENV TEMP=/home/model-server/tmp
LABEL com.amazonaws.sagemaker.capabilities.accept-bind-to-port=true
98 changes: 98 additions & 0 deletions docker/3.0.5/resources/mms/ExecutionParameters.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package software.amazon.ai.mms.plugins.endpoint;

import com.google.gson.GsonBuilder;
import com.google.gson.annotations.SerializedName;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Properties;
import software.amazon.ai.mms.servingsdk.Context;
import software.amazon.ai.mms.servingsdk.ModelServerEndpoint;
import software.amazon.ai.mms.servingsdk.annotations.Endpoint;
import software.amazon.ai.mms.servingsdk.annotations.helpers.EndpointTypes;
import software.amazon.ai.mms.servingsdk.http.Request;
import software.amazon.ai.mms.servingsdk.http.Response;

/**
 * MMS endpoint plugin that answers {@code GET execution-parameters} with a JSON document
 * containing {@code MaxConcurrentTransforms}, {@code BatchStrategy} and {@code MaxPayloadInMB}.
 *
 * <p>This is the modified endpoint source code for the jar used in this container. To rebuild
 * the jar, clone the MMS repo:
 *
 * <pre>
 * git clone https://github.com/awslabs/mxnet-model-server.git
 * </pre>
 *
 * <p>Copy this file into
 * {@code plugins/endpoints/src/main/java/software/amazon/ai/mms/plugins/endpoints/} and then,
 * from the plugins directory, run:
 *
 * <pre>
 * ./gradlew fJ
 * </pre>
 *
 * <p>Modify the file in {@code plugins/endpoint/resources/META-INF/services/*} to specify this
 * file location, then build the JAR:
 *
 * <pre>
 * ./gradlew build
 * </pre>
 *
 * <p>The jar should be available in {@code plugins/endpoints/build/libs} as
 * {@code endpoints-1.0.jar}.
 */
@Endpoint(
        urlPattern = "execution-parameters",
        endpointType = EndpointTypes.INFERENCE,
        description = "Execution parameters endpoint")
public class ExecutionParameters extends ModelServerEndpoint {

    /** Number of bytes in one mebibyte; converts {@code max_request_size} to whole MB. */
    private static final int BYTES_PER_MB = 1024 * 1024;

    /** Fallback for {@code max_request_size} when the property is absent: 6 * 1024 * 1024. */
    private static final String DEFAULT_MAX_REQUEST_SIZE = "6291456";

    /** Fallback for {@code NUM_WORKERS} when the property is absent. */
    private static final String DEFAULT_NUM_WORKERS = "1";

    /** Batch strategy reported by this endpoint. */
    private static final String BATCH_STRATEGY = "MULTI_RECORD";

    /**
     * Writes the execution parameters, derived from the model server configuration, as
     * pretty-printed UTF-8 JSON to the response output stream.
     *
     * @param req the incoming request (not inspected)
     * @param rsp the response whose output stream receives the JSON body
     * @param ctx the model server context supplying the configuration properties
     * @throws IOException if writing to the response stream fails
     * @throws NumberFormatException if {@code max_request_size} or {@code NUM_WORKERS} is set
     *     to something that is not an integer
     */
    @Override
    public void doGet(Request req, Response rsp, Context ctx) throws IOException {
        Properties prop = ctx.getConfig();
        int maxRequestSize = intProperty(prop, "max_request_size", DEFAULT_MAX_REQUEST_SIZE);
        int numWorkers = intProperty(prop, "NUM_WORKERS", DEFAULT_NUM_WORKERS);

        SagemakerXgboostResponse response = new SagemakerXgboostResponse();
        response.setMaxConcurrentTransforms(numWorkers);
        response.setBatchStrategy(BATCH_STRATEGY);
        // Integer division: the limit is reported in whole MB, rounded down.
        response.setMaxPayloadInMB(maxRequestSize / BYTES_PER_MB);

        String json = new GsonBuilder().setPrettyPrinting().create().toJson(response);
        rsp.getOutputStream().write(json.getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Reads an integer property, ignoring surrounding whitespace.
     *
     * <p>Values loaded from a {@code .properties} file keep their trailing whitespace, which
     * {@link Integer#parseInt(String)} rejects, so the value is trimmed before parsing.
     *
     * @param prop the configuration to read from
     * @param key the property name
     * @param defaultValue the value to parse when the property is absent
     * @return the parsed integer
     * @throws NumberFormatException if the trimmed value is not an integer
     */
    private static int intProperty(Properties prop, String key, String defaultValue) {
        return Integer.parseInt(prop.getProperty(key, defaultValue).trim());
    }

    /** Response for Model server endpoint; serialized to JSON under the annotated names. */
    public static class SagemakerXgboostResponse {
        @SerializedName("MaxConcurrentTransforms")
        private int maxConcurrentTransforms;

        @SerializedName("BatchStrategy")
        private String batchStrategy;

        @SerializedName("MaxPayloadInMB")
        private int maxPayloadInMB;

        /** Creates a response with 4 concurrent transforms, MULTI_RECORD batching and 6 MB payload. */
        public SagemakerXgboostResponse() {
            maxConcurrentTransforms = 4;
            batchStrategy = "MULTI_RECORD";
            maxPayloadInMB = 6;
        }

        public int getMaxConcurrentTransforms() {
            return maxConcurrentTransforms;
        }

        public String getBatchStrategy() {
            return batchStrategy;
        }

        public int getMaxPayloadInMB() {
            return maxPayloadInMB;
        }

        public void setMaxConcurrentTransforms(int newMaxConcurrentTransforms) {
            maxConcurrentTransforms = newMaxConcurrentTransforms;
        }

        public void setBatchStrategy(String newBatchStrategy) {
            batchStrategy = newBatchStrategy;
        }

        public void setMaxPayloadInMB(int newMaxPayloadInMB) {
            maxPayloadInMB = newMaxPayloadInMB;
        }
    }
}
Loading