Add support for offline total segmentator training #504

Merged — 108 commits, Jul 11, 2024

Commits
f073436
Add timm vision encoder network
ioangatop Apr 24, 2024
d112070
Add conv2d and linear decoders
ioangatop Apr 24, 2024
86ad741
Merge branch '381-add-timm-encoder-networks' into 382-add-semanticseg…
ioangatop Apr 24, 2024
97442bd
Add semantic segmentation module
ioangatop Apr 24, 2024
60691e5
improve docs
ioangatop Apr 24, 2024
69d4b29
add config
ioangatop Apr 24, 2024
849a270
update
ioangatop Apr 24, 2024
2259004
add segmentation visualiser callback
ioangatop Apr 26, 2024
2b5a1ce
Add simple decoder networks for segmentation tasks (#404)
ioangatop Apr 30, 2024
24de9e3
Move model freezing functionality to `configure_model`
ioangatop Apr 30, 2024
a7e65e5
Add `timm` encoder networks (#403)
ioangatop Apr 30, 2024
3f663f2
update
ioangatop May 3, 2024
b8b6649
update callback
ioangatop May 3, 2024
2fb0ac8
Merge branch 'main' of https://github.com/kaiko-ai/eva
ioangatop May 6, 2024
ad392e8
Add SemanticSegmentation module (#410)
ioangatop May 6, 2024
124ee87
Merge branch 'main' of https://github.com/kaiko-ai/eva
ioangatop May 6, 2024
4ef3afb
Add `TotalSegmentator2D` segmentation downstream task (#413)
ioangatop May 7, 2024
c30419e
Merge branch 'main' into 385-create-a-callback-to-visualise-the-segme…
ioangatop May 7, 2024
c92d46a
udpates
ioangatop May 7, 2024
1914085
update with 401
ioangatop May 7, 2024
7a47d08
updates
ioangatop May 7, 2024
1396803
Merge branch '402-aggregated-feature-segmentation-downstream-evaluati…
ioangatop May 7, 2024
66c2129
Add dice score as training metric
ioangatop May 7, 2024
624bdce
updates
ioangatop May 7, 2024
23c205f
add visualization callback
ioangatop May 7, 2024
07fbd08
Add dice score in `TotalSegmentator2D` task (#423)
ioangatop May 7, 2024
723a08c
add segmentation logger
ioangatop May 7, 2024
a05565a
Merge branch '402-aggregated-feature-segmentation-downstream-evaluati…
ioangatop May 7, 2024
cadbaec
rm dev files
ioangatop May 7, 2024
b27aaaa
Merge branch '385-create-a-callback-to-visualise-the-segmentation-res…
ioangatop May 7, 2024
bc681c8
update with main
ioangatop May 7, 2024
ffc2768
fix lint tests
ioangatop May 7, 2024
bae6ae6
Allow to parametrize the classes of the total segmentation 2d dataset
ioangatop May 10, 2024
5080dc7
Merge branch '402-aggregated-feature-segmentation-downstream-evaluati…
ioangatop May 10, 2024
0fc8395
lint
ioangatop May 10, 2024
8d45680
Merge branch '434-allow-to-use-subclasses-in-totalsegmentator2d' of h…
ioangatop May 10, 2024
69c58ca
update
ioangatop May 10, 2024
8586b20
Merge branch '385-create-a-callback-to-visualise-the-segmentation-res…
ioangatop May 10, 2024
50001d5
updates
ioangatop May 10, 2024
a6fd246
Allow to use subclasses in `TotalSegmentator2D` (#435)
ioangatop May 13, 2024
db5a578
Create a callback to visualise the segmentation results (#424)
ioangatop May 13, 2024
7656ae6
Improve the mask loading in `TotalSegmentator2D` (#440)
ioangatop May 15, 2024
702f634
Add per class metrics dice score in `TotalSegmentator2D` (#447)
ioangatop May 16, 2024
434cacb
Support `int16` training on `TotalSegementator2D` (#443)
ioangatop May 16, 2024
afde9b1
update with 402
ioangatop May 17, 2024
532504b
refactor embeddings writer
ioangatop May 21, 2024
9a5eca7
minor rename
ioangatop May 21, 2024
0b95b74
improve docs
ioangatop May 22, 2024
a0674d4
add segmentation writer
ioangatop May 22, 2024
9166ad7
add seg writer
ioangatop May 22, 2024
3a85809
working writer
ioangatop May 22, 2024
377c353
minor update
ioangatop May 22, 2024
35bb532
Normalisations and transforms for `int16` image types (#457)
ioangatop May 22, 2024
3f3794e
Merge branch '402-aggregated-feature-segmentation-downstream-evaluati…
ioangatop May 22, 2024
09a514c
update
ioangatop May 23, 2024
ffc6186
add config callback
ioangatop May 23, 2024
6dd7aa2
update all config files
ioangatop May 23, 2024
3b71e7e
fix loggin
ioangatop May 23, 2024
22a2a92
update parsing method
ioangatop May 23, 2024
58f73f9
added multi instance learning support
nkaenzig May 23, 2024
b5debfc
udpates
ioangatop May 24, 2024
48654b8
Merge branch '460-refactor-embeddingswriter-class' of https://github.…
ioangatop May 24, 2024
a45b7be
Merge branch '460-refactor-embeddingswriter-class' into 436-add-offli…
ioangatop May 24, 2024
9187447
updates
ioangatop May 29, 2024
c54b63b
update with main
ioangatop Jun 3, 2024
d58f6b4
merge with main
ioangatop Jun 3, 2024
7eeaa52
udpates
ioangatop Jun 3, 2024
ff69754
fix flow
ioangatop Jun 3, 2024
1f235cd
Merge branch 'main' into 436-add-offline-total-segmentator-configuration
ioangatop Jun 4, 2024
735f305
updates
ioangatop Jun 5, 2024
7f34d39
udpates
ioangatop Jun 6, 2024
8fad9aa
Add support for multi-level feature embeddings
ioangatop Jun 6, 2024
218dba3
updates
ioangatop Jun 6, 2024
d1ba913
Fix default segmentation metrics
ioangatop Jun 6, 2024
6cba65b
Merge branch '502-fix-default-segmentation-metrics' into 436-add-offl…
ioangatop Jun 6, 2024
abaa2cd
Merge branch '500-support-for-multi-level-embeddings-training-in-segm…
ioangatop Jun 6, 2024
91608e1
updates
ioangatop Jun 6, 2024
541c594
Finish offline segmentation training
ioangatop Jun 6, 2024
c43b8ad
Fix default segmentation metrics (#503)
ioangatop Jun 6, 2024
b009aec
Add support for multi-level embeddings training in segmentation tasks…
ioangatop Jun 6, 2024
19c92cd
merge with main
ioangatop Jun 6, 2024
cf54cad
update with main
ioangatop Jun 6, 2024
07b135a
updates
ioangatop Jun 6, 2024
352f2a6
updates
ioangatop Jun 6, 2024
f02b773
rm dev files
ioangatop Jun 6, 2024
2abfe16
updates
ioangatop Jun 6, 2024
f3be716
updates
ioangatop Jun 10, 2024
0edcf0d
merge with main
ioangatop Jul 3, 2024
f978b89
udpates
ioangatop Jul 3, 2024
f153a28
Merge branch 'main' into 436-add-offline-total-segmentator-configuration
ioangatop Jul 3, 2024
2e40f0b
resolve conflicts
ioangatop Jul 3, 2024
dd881f6
rm old files
ioangatop Jul 3, 2024
b581bcf
fixes
ioangatop Jul 3, 2024
5407582
add monusac offline config
ioangatop Jul 3, 2024
bb3ffdb
minor fix
ioangatop Jul 3, 2024
c5775f2
enable offline segmentation for multi-level embeddings
ioangatop Jul 4, 2024
8d7c7f2
fix consep filename
ioangatop Jul 4, 2024
6379bd1
export metadata
ioangatop Jul 9, 2024
5b948a6
clean up
ioangatop Jul 9, 2024
5a2e665
fixes in total seg
ioangatop Jul 9, 2024
13c7c24
add tests
ioangatop Jul 9, 2024
431f4f9
updates
ioangatop Jul 9, 2024
bb30ae8
rm unused functions
ioangatop Jul 9, 2024
6596ae5
Merge branch 'main' into 436-add-offline-total-segmentator-configuration
ioangatop Jul 10, 2024
c41f92c
fix lint
ioangatop Jul 10, 2024
2d255e3
Merge branch 'main' into 436-add-offline-total-segmentator-configuration
ioangatop Jul 10, 2024
4bd9416
update counter
ioangatop Jul 10, 2024
51d59c4
fix lint
ioangatop Jul 10, 2024
124 changes: 124 additions & 0 deletions configs/vision/dino_vit/offline/consep.yaml
@@ -0,0 +1,124 @@
---
trainer:
  class_path: eva.Trainer
  init_args:
    n_runs: &N_RUNS ${oc.env:N_RUNS, 1}
    default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:TIMM_MODEL_NAME, vit_small_patch16_224_dino}/consep}
    max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 513}
    log_every_n_steps: 6
    callbacks:
      - class_path: eva.callbacks.ConfigurationLogger
      - class_path: eva.vision.callbacks.SemanticSegmentationLogger
        init_args:
          log_every_n_epochs: 1
          log_images: false
      - class_path: lightning.pytorch.callbacks.ModelCheckpoint
        init_args:
          filename: best
          save_last: true
          save_top_k: 1
          monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MulticlassJaccardIndex}
          mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max}
      - class_path: lightning.pytorch.callbacks.EarlyStopping
        init_args:
          min_delta: 0
          patience: 100
          monitor: *MONITOR_METRIC
          mode: *MONITOR_METRIC_MODE
      - class_path: eva.callbacks.SegmentationEmbeddingsWriter
        init_args:
          output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:TIMM_MODEL_NAME, vit_small_patch16_224_dino}/consep
          dataloader_idx_map:
            0: train
            1: val
          metadata_keys: ["coords"]
          backbone:
            class_path: eva.vision.models.networks.encoders.TimmEncoder
            init_args:
              model_name: ${oc.env:TIMM_MODEL_NAME, vit_small_patch16_224_dino}
              pretrained: ${oc.env:MODEL_PRETRAINED, true}
              out_indices: ${oc.env:TIMM_MODEL_OUT_INDICES, 3}
              checkpoint_path: &CHECKPOINT_PATH ${oc.env:CHECKPOINT_PATH, null}
              model_arguments:
                dynamic_img_size: true
    logger:
      - class_path: lightning.pytorch.loggers.TensorBoardLogger
        init_args:
          save_dir: *OUTPUT_ROOT
          name: ""
model:
  class_path: eva.vision.models.modules.SemanticSegmentationModule
  init_args:
    decoder:
      class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderMS
      init_args:
        in_features: ${oc.env:DECODER_IN_FEATURES, 1152}
        num_classes: &NUM_CLASSES 8
    criterion: torch.nn.CrossEntropyLoss
    lr_multiplier_encoder: 0.0
    optimizer:
      class_path: torch.optim.AdamW
      init_args:
        lr: 0.001
        weight_decay: 0.01
    lr_scheduler:
      class_path: torch.optim.lr_scheduler.PolynomialLR
      init_args:
        total_iters: *MAX_STEPS
        power: 0.9
    metrics:
      common:
        - class_path: eva.metrics.AverageLoss
      evaluation:
        - class_path: eva.core.metrics.defaults.MulticlassSegmentationMetrics
          init_args:
            num_classes: *NUM_CLASSES
        - class_path: eva.core.metrics.wrappers.ClasswiseWrapper
          init_args:
            metric:
              class_path: torchmetrics.classification.MulticlassF1Score
              init_args:
                num_classes: *NUM_CLASSES
                average: null
data:
  class_path: eva.DataModule
  init_args:
    datasets:
      train:
        class_path: eva.vision.datasets.EmbeddingsSegmentationDataset
        init_args: &DATASET_ARGS
          root: *DATASET_EMBEDDINGS_ROOT
          manifest_file: manifest.csv
          split: train
      val:
        class_path: eva.vision.datasets.EmbeddingsSegmentationDataset
        init_args:
          <<: *DATASET_ARGS
          split: val
      predict:
        - class_path: eva.vision.datasets.CoNSeP
          init_args: &PREDICT_DATASET_ARGS
            root: ${oc.env:DATA_ROOT, ./data}/consep
            split: train
            sampler:
              class_path: eva.vision.data.wsi.patching.samplers.ForegroundGridSampler
              init_args:
                max_samples: 25
            transforms:
              class_path: eva.vision.data.transforms.common.ResizeAndCrop
              init_args:
                size: ${oc.env:RESIZE_DIM, 224}
                mean: &NORMALIZE_MEAN ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]}
                std: &NORMALIZE_STD ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]}
        - class_path: eva.vision.datasets.CoNSeP
          init_args:
            <<: *PREDICT_DATASET_ARGS
            split: val
    dataloaders:
      train:
        batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 64}
        shuffle: true
      val:
        batch_size: *BATCH_SIZE
      predict:
        batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64}
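These configs resolve `${oc.env:VAR, default}` placeholders through OmegaConf's `oc.env` resolver, so every default above (model name, max steps, paths) can be overridden from the environment. A minimal pure-Python sketch of that lookup behaviour — the `resolve_env` helper and its regex are illustrative, not the OmegaConf implementation, and they ignore nested interpolations such as the one in `default_root_dir`:

```python
import re

def resolve_env(template: str, env: dict) -> str:
    # Replace each ${oc.env:NAME, default} with env[NAME] if set, else the default.
    pattern = re.compile(r"\$\{oc\.env:([A-Z0-9_]+),\s*([^}]*)\}")
    return pattern.sub(lambda m: env.get(m.group(1), m.group(2)), template)

env = {"TIMM_MODEL_NAME": "vit_base_patch16_224"}
print(resolve_env("${oc.env:MAX_STEPS, 513}", env))  # falls back to the default: 513
print(resolve_env("logs/${oc.env:TIMM_MODEL_NAME, vit_small_patch16_224_dino}/consep", env))
```

In practice one would export the variables (e.g. `TIMM_MODEL_NAME`, `DECODER_IN_FEATURES`) before launching eva rather than editing the YAML.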
124 changes: 124 additions & 0 deletions configs/vision/dino_vit/offline/monusac.yaml
@@ -0,0 +1,124 @@
---
trainer:
  class_path: eva.Trainer
  init_args:
    n_runs: &N_RUNS ${oc.env:N_RUNS, 1}
    default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:TIMM_MODEL_NAME, vit_small_patch16_224_dino}/monusac}
    max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 513}
    log_every_n_steps: 6
    callbacks:
      - class_path: eva.callbacks.ConfigurationLogger
      - class_path: eva.vision.callbacks.SemanticSegmentationLogger
        init_args:
          log_every_n_epochs: 1
          log_images: false
      - class_path: lightning.pytorch.callbacks.ModelCheckpoint
        init_args:
          filename: best
          save_last: true
          save_top_k: 1
          monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MulticlassJaccardIndex}
          mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max}
      - class_path: lightning.pytorch.callbacks.EarlyStopping
        init_args:
          min_delta: 0
          patience: 100
          monitor: *MONITOR_METRIC
          mode: *MONITOR_METRIC_MODE
      - class_path: eva.callbacks.SegmentationEmbeddingsWriter
        init_args:
          output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:TIMM_MODEL_NAME, vit_small_patch16_224_dino}/monusac
          dataloader_idx_map:
            0: train
            1: test
          backbone:
            class_path: eva.vision.models.networks.encoders.TimmEncoder
            init_args:
              model_name: ${oc.env:TIMM_MODEL_NAME, vit_small_patch16_224_dino}
              pretrained: ${oc.env:MODEL_PRETRAINED, true}
              out_indices: ${oc.env:TIMM_MODEL_OUT_INDICES, 1}
              checkpoint_path: &CHECKPOINT_PATH ${oc.env:CHECKPOINT_PATH, null}
              model_arguments:
                dynamic_img_size: true
    logger:
      - class_path: lightning.pytorch.loggers.TensorBoardLogger
        init_args:
          save_dir: *OUTPUT_ROOT
          name: ""
model:
  class_path: eva.vision.models.modules.SemanticSegmentationModule
  init_args:
    decoder:
      class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderMS
      init_args:
        in_features: ${oc.env:DECODER_IN_FEATURES, 384}
        num_classes: &NUM_CLASSES 8
    criterion: torch.nn.CrossEntropyLoss
    lr_multiplier_encoder: 0.0
    optimizer:
      class_path: torch.optim.AdamW
      init_args:
        lr: 0.001
        weight_decay: 0.01
    lr_scheduler:
      class_path: torch.optim.lr_scheduler.PolynomialLR
      init_args:
        total_iters: *MAX_STEPS
        power: 0.9
    metrics:
      common:
        - class_path: eva.metrics.AverageLoss
      evaluation:
        - class_path: eva.core.metrics.defaults.MulticlassSegmentationMetrics
          init_args:
            num_classes: *NUM_CLASSES
        - class_path: eva.core.metrics.wrappers.ClasswiseWrapper
          init_args:
            metric:
              class_path: torchmetrics.classification.MulticlassF1Score
              init_args:
                num_classes: *NUM_CLASSES
                average: null
data:
  class_path: eva.DataModule
  init_args:
    datasets:
      train:
        class_path: eva.vision.datasets.EmbeddingsSegmentationDataset
        init_args: &DATASET_ARGS
          root: *DATASET_EMBEDDINGS_ROOT
          manifest_file: manifest.csv
          split: train
      val:
        class_path: eva.vision.datasets.EmbeddingsSegmentationDataset
        init_args:
          <<: *DATASET_ARGS
          split: test
      predict:
        - class_path: eva.vision.datasets.MoNuSAC
          init_args: &PREDICT_DATASET_ARGS
            root: ${oc.env:DATA_ROOT, ./data}/monusac
            split: train
            download: ${oc.env:DOWNLOAD_DATA, false}
            # Set `download: true` to download the dataset from https://monusac-2020.grand-challenge.org/Data/
            # The MoNuSAC dataset is distributed under the following license:
            # "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International"
            # (see: https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode)
            transforms:
              class_path: eva.vision.data.transforms.common.ResizeAndCrop
              init_args:
                size: ${oc.env:RESIZE_DIM, 224}
                mean: &NORMALIZE_MEAN ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]}
                std: &NORMALIZE_STD ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]}
        - class_path: eva.vision.datasets.MoNuSAC
          init_args:
            <<: *PREDICT_DATASET_ARGS
            split: test
    dataloaders:
      train:
        batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 64}
        shuffle: true
      val:
        batch_size: *BATCH_SIZE
      predict:
        batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64}
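The `predict` datasets above reuse the `&PREDICT_DATASET_ARGS` anchor through the YAML merge key `<<:`, where sibling keys override the merged-in mapping — so the second entry inherits `root` and `download` but swaps `split`. A pure-Python sketch of that override semantics (the `merge` helper is illustrative; a YAML loader does this flattening itself):

```python
def merge(base: dict, overrides: dict) -> dict:
    # YAML merge key semantics: start from the anchored mapping, local keys win.
    merged = dict(base)
    merged.update(overrides)
    return merged

# Mirrors the &PREDICT_DATASET_ARGS anchor in the config above.
predict_dataset_args = {"root": "./data/monusac", "split": "train", "download": False}

test_dataset_args = merge(predict_dataset_args, {"split": "test"})
print(test_dataset_args)  # {'root': './data/monusac', 'split': 'test', 'download': False}
```

The same pattern drives `&DATASET_ARGS` for the train/val embeddings datasets.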
119 changes: 119 additions & 0 deletions configs/vision/dino_vit/offline/total_segmentator_2d.yaml
@@ -0,0 +1,119 @@
trainer:
  class_path: eva.Trainer
  init_args:
    n_runs: &N_RUNS ${oc.env:N_RUNS, 1}
    default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:TIMM_MODEL_NAME, vit_small_patch16_224}/total_segmentator_2d}
    max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 20000}
    callbacks:
      - class_path: eva.callbacks.ConfigurationLogger
      - class_path: eva.vision.callbacks.SemanticSegmentationLogger
        init_args:
          log_every_n_steps: 1000
          log_images: false
      - class_path: eva.callbacks.SegmentationEmbeddingsWriter
        init_args:
          output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:DINO_BACKBONE, dino_vits16}/total_segmentator_2d
          dataloader_idx_map:
            0: train
            1: val
            2: test
          metadata_keys: ["slice_index"]
          backbone:
            class_path: eva.vision.models.networks.encoders.TimmEncoder
            init_args:
              model_name: ${oc.env:TIMM_MODEL_NAME, vit_small_patch16_224_dino}
              pretrained: ${oc.env:MODEL_PRETRAINED, true}
              out_indices: ${oc.env:TIMM_MODEL_OUT_INDICES, 1}
              model_arguments:
                dynamic_img_size: true
    logger:
      - class_path: lightning.pytorch.loggers.TensorBoardLogger
        init_args:
          save_dir: *OUTPUT_ROOT
          name: ""
model:
  class_path: eva.vision.models.modules.SemanticSegmentationModule
  init_args:
    decoder:
      class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoder
      init_args:
        layers:
          class_path: torch.nn.Conv2d
          init_args:
            in_channels: ${oc.env:DECODER_IN_FEATURES, 384}
            out_channels: &NUM_CLASSES 118
            kernel_size: [1, 1]
    criterion: torch.nn.CrossEntropyLoss
    optimizer:
      class_path: torch.optim.AdamW
      init_args:
        lr: 0.0001
        weight_decay: 0.05
    lr_scheduler:
      class_path: torch.optim.lr_scheduler.PolynomialLR
      init_args:
        total_iters: *MAX_STEPS
        power: 0.9
    metrics:
      common:
        - class_path: eva.metrics.AverageLoss
      evaluation:
        - class_path: eva.core.metrics.defaults.MulticlassSegmentationMetrics
          init_args:
            num_classes: *NUM_CLASSES
        - class_path: eva.core.metrics.wrappers.ClasswiseWrapper
          init_args:
            metric:
              class_path: torchmetrics.classification.MulticlassF1Score
              init_args:
                num_classes: *NUM_CLASSES
                average: null
data:
  class_path: eva.DataModule
  init_args:
    datasets:
      train:
        class_path: eva.vision.datasets.EmbeddingsSegmentationDataset
        init_args: &DATASET_ARGS
          root: *DATASET_EMBEDDINGS_ROOT
          manifest_file: manifest.csv
          split: train
      val:
        class_path: eva.vision.datasets.EmbeddingsSegmentationDataset
        init_args:
          <<: *DATASET_ARGS
          split: val
      predict:
        - class_path: eva.vision.datasets.TotalSegmentator2D
          init_args: &PREDICT_DATASET_ARGS
            root: ${oc.env:DATA_ROOT, ./data}/total_segmentator
            split: train
            download: ${oc.env:DOWNLOAD_DATA, false}
            # Set `download: true` to download the dataset from https://zenodo.org/records/10047292
            # The TotalSegmentator dataset is distributed under the following license:
            # "Creative Commons Attribution 4.0 International"
            # (see: https://creativecommons.org/licenses/by/4.0/deed.en)
            transforms:
              class_path: eva.vision.data.transforms.common.ResizeAndCrop
              init_args:
                size: ${oc.env:RESIZE_DIM, 224}
                mean: &NORMALIZE_MEAN ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]}
                std: &NORMALIZE_STD ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]}
        - class_path: eva.vision.datasets.TotalSegmentator2D
          init_args:
            <<: *PREDICT_DATASET_ARGS
            split: val
        - class_path: eva.vision.datasets.TotalSegmentator2D
          init_args:
            <<: *PREDICT_DATASET_ARGS
            split: test
    dataloaders:
      train:
        batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 64}
        shuffle: true
      val:
        batch_size: *BATCH_SIZE
      test:
        batch_size: *BATCH_SIZE
      predict:
        batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64}
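The TotalSegmentator2D decoder above is a single `torch.nn.Conv2d` with `kernel_size: [1, 1]`, i.e. a per-location linear map from 384 encoder features to 118 class logits. A NumPy sketch of that equivalence — the 14×14 patch grid and the random weights are hypothetical stand-ins, not values from eva:

```python
import numpy as np

rng = np.random.default_rng(0)
in_features, num_classes = 384, 118  # DECODER_IN_FEATURES and NUM_CLASSES from the config
h = w = 14                           # hypothetical patch-token grid for a 224px ViT-S/16 input

features = rng.standard_normal((in_features, h, w))
weight = rng.standard_normal((num_classes, in_features))
bias = rng.standard_normal(num_classes)

# A 1x1 convolution applies the same linear map independently at every (y, x) location.
logits = np.einsum("oc,chw->ohw", weight, features) + bias[:, None, None]
print(logits.shape)  # (118, 14, 14)
```

The resulting low-resolution logit map is then upsampled to the mask size before the cross-entropy loss is applied.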
8 changes: 4 additions & 4 deletions configs/vision/dino_vit/online/total_segmentator_2d.yaml
@@ -10,8 +10,8 @@ trainer:
       - class_path: eva.vision.callbacks.SemanticSegmentationLogger
         init_args:
           log_every_n_epochs: 1
-          mean: &NORMALIZE_MEAN [0.5, 0.5, 0.5]
-          std: &NORMALIZE_STD [0.5, 0.5, 0.5]
+          mean: &NORMALIZE_MEAN ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]}
+          std: &NORMALIZE_STD ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]}
       - class_path: lightning.pytorch.callbacks.ModelCheckpoint
         init_args:
           filename: best
@@ -36,8 +36,8 @@ model:
     encoder:
       class_path: eva.vision.models.networks.encoders.TimmEncoder
       init_args:
-        model_name: ${oc.env:TIMM_MODEL_NAME, vit_small_patch16_224}
-        pretrained: true
+        model_name: ${oc.env:TIMM_MODEL_NAME, vit_small_patch16_224_dino}
+        pretrained: ${oc.env:MODEL_PRETRAINED, true}
         out_indices: ${oc.env:TIMM_MODEL_OUT_INDICES, 1}
         model_arguments:
           dynamic_img_size: true
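This diff replaces the fixed `[0.5, 0.5, 0.5]` constants with the ImageNet statistics used by the offline configs, keeping normalization consistent across tasks. A small sketch of the channel-wise normalization these values parameterize (`normalize_pixel` is an illustrative helper, not part of eva):

```python
# ImageNet statistics, matching the NORMALIZE_MEAN / NORMALIZE_STD defaults above.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb):
    # Standard channel-wise normalization: (value - mean) / std per channel,
    # applied to pixel values already scaled to [0, 1].
    return [(value - m) / s for value, m, s in zip(rgb, MEAN, STD)]

print(normalize_pixel([0.485, 0.456, 0.406]))  # a mean-valued pixel maps to [0.0, 0.0, 0.0]
```

Matching these constants to the encoder's pretraining statistics matters because the frozen DINO backbone was trained on inputs normalized this way.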