-
Notifications
You must be signed in to change notification settings - Fork 3
/
flake.nix
160 lines (154 loc) · 4.92 KB
/
flake.nix
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
{
description = "JAX implementation of instant-ngp (NeRF part)";
# Flake inputs. Every URL carries an explicit ref or revision.
inputs = {
  # Main package set, taken at the `22.11` ref.
  nixpkgs.url = "github:nixos/nixpkgs/22.11";
  # Head of nixpkgs pull request 222762; only its `linuxPackages` is consumed
  # (see the `cudaPkgs` package override).
  nixpkgs-with-nvidia-driver-fix.url = "github:nixos/nixpkgs/pull/222762/head";
  flake-utils.url = "github:numtide/flake-utils/3db36a8b464d0c4532ba1c7dda728f4576d6d073";
  # nixGL, made to share this flake's nixpkgs and flake-utils instead of its own.
  nixgl.url = "github:guibou/nixgl/c917918ab9ebeee27b0dd657263d3f57ba6bb8ad";
  nixgl.inputs.nixpkgs.follows = "nixpkgs";
  nixgl.inputs.flake-utils.follows = "flake-utils";
};
outputs = inputs@{ self, nixpkgs, flake-utils, ... }: let
deps = import ./deps;
in flake-utils.lib.eachSystem [ "x86_64-linux" "aarch64-linux" ] (system: let
inherit (nixpkgs) lib;
# nixpkgs instantiated for `system` with only this flake's own overlay applied;
# it is the package set handed to `deps.packages` for the `packages` output.
basePkgs = import nixpkgs {
  overlays = [ self.overlays.default ];
  inherit system;
};
in {
devShells = let
# Version suffix of the Python interpreter attribute to use.
pyVer = "310";
# Interpreter attribute name (evaluates to "python310"); used to index package
# sets dynamically, e.g. `prev.${py}` and `cudaPkgs.${py}`.
py = "python${pyVer}";
# Overlay that trims expensive builds out of the JAX/TensorFlow dependency tree.
# It overrides top-level `opencv4` and the Python package scope of the
# interpreter selected by `py`.
jaxOverlays = _final: prev: let
  # Attribute override that only switches the test phase off.
  skipChecks = _old: { doCheck = false; };
  # Overrides applied inside the Python package scope.
  pythonOverrides = _pyFinal: pyPrev: {
    jax = pyPrev.jax.overridePythonAttrs skipChecks;
    # Use the prebuilt jaxlib wherever `jaxlib` is requested.
    jaxlib = pyPrev.jaxlib-bin;
    flax = pyPrev.flax.overridePythonAttrs (old: {
      buildInputs = old.buildInputs ++ [ pyPrev.pyyaml ];
      doCheck = false;
    });
    # we only use tensorflow-datasets for data loading, it does not need to be built
    # with cuda support (building with cuda support is too demanding).
    tensorflow = pyPrev.tensorflow.override { cudaSupport = false; };
  };
in {
  # avoid rebuilding opencv4 with cuda for tensorflow-datasets
  opencv4 = prev.opencv4.override { enableCuda = false; };
  ${py} = prev.${py}.override { packageOverrides = pythonOverrides; };
};
# Overlay stack shared by the CUDA and CPU package sets, applied left to right:
# nixGL first, then this flake's own overlay, then the JAX-specific overrides.
overlays = [ inputs.nixgl.overlays.default self.overlays.default jaxOverlays ];
# Package set for the CUDA dev shell: unfree packages allowed, CUDA support on,
# and `linuxPackages` swapped for the one from the
# `nixpkgs-with-nvidia-driver-fix` input.
cudaPkgs = import nixpkgs {
  inherit system overlays;
  config = {
    allowUnfree = true;
    cudaSupport = true;
    packageOverrides = pkgs: {
      # `system` is passed explicitly so this nested import is instantiated for
      # the same system as the surrounding package set; with an empty argument
      # set nixpkgs falls back to `builtins.currentSystem`, which is unavailable
      # in pure evaluation and may differ from `system`.
      # NOTE(review): no `config` is forwarded here, so `allowUnfree` above does
      # not apply to this inner import -- confirm the driver evaluates without it.
      linuxPackages = (import inputs.nixpkgs-with-nvidia-driver-fix {
        inherit system;
      }).linuxPackages;
    };
  };
};
# Package set for the CPU dev shell: same overlays as the CUDA set, but with
# CUDA support switched off.
cpuPkgs = import nixpkgs {
  inherit overlays system;
  config.allowUnfree = true;
  config.cudaSupport = false; # NOTE: disable cuda for cpu env
};
# Builds the Python package list for a dev shell.
#
# Arguments:
#   pp            - the Python package set supplied by `withPackages`.
#   extraPackages - shell-specific packages appended after the common ones.
#
# Names are looked up in `pp` via `with pp;`, so the `pkgs.*` entries resolve
# to `pp.pkgs.*`.
mkPythonDeps = { pp, extraPackages }: with pp; [
  # interactive / debugging helpers
  ipython
  ipdb
  icecream
  tqdm
  colorama
  # image and video I/O (`pillow` was previously listed twice)
  pillow
  imageio
  ffmpeg-python
  matplotlib
  # misc utilities
  pydantic
  natsort
  GitPython
  # entries taken from the `pkgs` attribute of the Python package set
  pkgs.dearpygui
  pkgs.pycolmap
  pkgs.tyro
  # ML stack
  tensorflow
  keras
  jaxlib-bin
  jax
  optax
  flax
] ++ extraPackages;
# Shell hook fragment shared by every dev shell, assembled one command per
# entry; each entry is terminated with a newline.
commonShellHook = lib.concatMapStrings (command: command + "\n") [
  # Python debugging / output settings.
  ''export PYTHONBREAKPOINT=ipdb.set_trace''
  ''export PYTHONDONTWRITEBYTECODE=1''
  ''export PYTHONUNBUFFERED=1''
  # When the hook runs in an interactive shell (`$-` contains `i`), replace
  # that shell with the user's `$SHELL`.
  ''[[ "$-" == *i* ]] && exec "$SHELL"''
];
in rec {
default = cudaDevShell;
# CUDA development shell. Impure: evaluation inspects absolute host paths
# (/usr/lib/wsl/lib and /proc/driver/nvidia/version).
cudaDevShell = let # impure
  # The presence of this directory is used as the WSL indicator.
  isWsl = builtins.pathExists /usr/lib/wsl/lib;
  # Python environment: common dependencies plus the JAX extension packages.
  pythonEnv = cudaPkgs.${py}.withPackages (pp: mkPythonDeps {
    inherit pp;
    extraPackages = with pp; [
      pkgs.spherical-harmonics-encoding-jax
      pkgs.volume-rendering-jax
      pkgs.jax-tcnn
    ];
  });
in cudaPkgs.mkShell {
  name = "cuda";
  buildInputs = [
    cudaPkgs.colmap-locked
    cudaPkgs.ffmpeg
    pythonEnv
  ];
  # REF:
  # <https://github.com/google/jax/issues/5723#issuecomment-1339655621>
  # Force single-threaded XLA GPU compilation on non-WSL hosts whose NVIDIA
  # driver major version is <= 470. The flag is left empty when the driver
  # version cannot be determined (version file missing or in an unexpected
  # format) instead of aborting evaluation.
  XLA_FLAGS = let
    versionFile = /proc/driver/nvidia/version;
    # `null` when the file is absent or the version line does not match.
    # One or more spaces are accepted on either side of the version number.
    versionMatch =
      if builtins.pathExists versionFile
      then builtins.match ".*Module +([0-9.]+) +.*" (builtins.readFile versionFile)
      else null;
    nvidiaDriverVersionMajor =
      if versionMatch == null
      then null
      else lib.toInt (builtins.head (builtins.splitVersion (builtins.head versionMatch)));
  in lib.optionalString
    # `&&` short-circuits, so the version file is never read under WSL.
    (!isWsl && nvidiaDriverVersionMajor != null && nvidiaDriverVersionMajor <= 470)
    "--xla_gpu_force_compilation_parallelism=1";
  # Source the nixGL wrapper scripts with every line containing `$@` removed,
  # so their setup is applied to the current shell. Under WSL the NVIDIA
  # wrapper is replaced by prepending /usr/lib/wsl/lib to LD_LIBRARY_PATH.
  shellHook = ''
    source <(sed -Ee '/\$@/d' ${lib.getExe cudaPkgs.nixgl.nixGLIntel})
  '' + (if isWsl
    then ''export LD_LIBRARY_PATH=/usr/lib/wsl/lib''${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}''
    else ''source <(sed -Ee '/\$@/d' ${lib.getExe cudaPkgs.nixgl.auto.nixGLNvidia}*)''
  ) + "\n" + commonShellHook;
};
# CPU-only development shell: same tools and common Python dependencies as the
# CUDA shell, built from the CUDA-disabled package set, with no extra Python
# packages.
cpuDevShell = let
  pythonEnv = cpuPkgs.${py}.withPackages (pp: mkPythonDeps {
    extraPackages = [];
    inherit pp;
  });
in cpuPkgs.mkShell {
  name = "cpu";
  buildInputs = [ cpuPkgs.colmap-locked cpuPkgs.ffmpeg pythonEnv ];
  # No CPU-specific setup; the shared hook is used as-is.
  shellHook = commonShellHook;
};
};
packages = deps.packages basePkgs;
}) // {
overlays.default = deps.overlay;
};
}