Pre-training on a subset of the original input channels #26
Comments
@gabrieltseng I should come back to this. It looks to me now that Presto expects the same elements to be masked across an entire batch? https://github.com/nasaharvest/presto/blob/main/presto/presto.py#L369 Or do I misunderstand this? Still thinking about the case where in some samples we could miss some values. This results in a batch where not all elements are equally masked. Is this something Presto cannot handle? FYI, Side information: I'm now detecting which timestep/band combinations are no data (e.g. when we don't interpolate missing monthly values in optical) and I'm manually setting corresponding |
Hi @kvantricht , You are right - we currently don't support masking interim timesteps in The reason for this is an artifact of how we export the data (e.g. since we select the least cloudy pixel for Sentinel-2, we should never have a missing pixel, just the least cloudy one). However, I can update this function to make this easier. With respect to the consistent number of masked tokens within a batch, that is correct. The reason for this is that within the model, we end up passing a tensor of shape |
Following up on this, we're trying to create our own I'm a bit confused about this step. It seems to call the In case of the former, I'm not sure how currently the code applies this mask to the entire batch. And in case of the latter, it seems that Would you be able to clarify? |
Every sample in the batch is masked differently. However, this is just the case for pre-training; our finetuning code is available here. This code assumes a static mask for the whole task (of shape |
This is indeed what I thought, so I don't really understand why every call to
Gotcha, but I'm puzzled by things like |
For example, one token is S1 VV and VH (2 values) and one token is S2 RGB (3 values). So if I were to only mask one token, I might sometimes mask 2 values or 3 values. The |
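To make the token-versus-value distinction above concrete, here is a minimal sketch. The group names and channel indices in `BAND_GROUPS` are hypothetical and chosen only for illustration; they are not Presto's actual band layout. The point is just that masking one token sets a different number of mask values depending on how many channels the token's group holds.

```python
import numpy as np

# Hypothetical channel layout (an assumption for illustration only):
# one S1 token covers 2 channels (VV, VH), one S2 RGB token covers 3.
BAND_GROUPS = {"S1": [0, 1], "S2_RGB": [2, 3, 4]}


def mask_token(mask, timestep, group):
    """Return a copy of `mask` with one (timestep, band group) token masked.

    `mask` has shape (timesteps, channels); 1 means "masked".
    """
    out = mask.copy()
    out[timestep, BAND_GROUPS[group]] = 1
    return out


empty = np.zeros((3, 5), dtype=int)
s1_masked = mask_token(empty, 0, "S1")       # masks 2 values
rgb_masked = mask_token(empty, 0, "S2_RGB")  # masks 3 values
print(int(s1_masked.sum()), int(rgb_masked.sum()))
```

So counting masked values in the raw mask is not the same as counting masked tokens; any consistency check across samples would have to be done per token, not per value.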
Right, but then it brings me back to the original issue where the
And yes, at the moment we work with the full Presto model. |
Could you share the code snippet you use to generate the masks? Within a batch, the same number of tokens must be passed / masked (this is what the assert statement was designed to check for). As you noted before, using a |
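As a rough illustration of the constraint described above, the following sketch checks that every sample in a batch has the same number of masked tokens. It assumes a token-level boolean or 0/1 mask of shape (batch, tokens); the function name `check_equal_masked_tokens` is hypothetical and this is not the assert from the Presto code itself.

```python
import numpy as np


def check_equal_masked_tokens(token_mask):
    """Return the per-sample masked-token count, or raise if samples differ.

    `token_mask` has shape (batch, tokens); nonzero means "masked".
    """
    counts = np.asarray(token_mask).astype(bool).sum(axis=1)
    if not np.all(counts == counts[0]):
        raise ValueError(f"unequal masked-token counts per sample: {counts.tolist()}")
    return int(counts[0])


ok_batch = np.array([[1, 0, 1], [0, 1, 1]])
n_masked = check_equal_masked_tokens(ok_batch)
```

A batch in which one sample has extra no-data positions masked would fail this check, which matches the situation discussed in this thread.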
Okay yeah indeed, some example code is more practical to discuss. I made a notebook here which takes a sample WorldCereal dataframe and attempts to transform it into something Presto accepts: https://github.com/kvantricht/presto/blob/worldcereal/notebooks/presto_pretrain_finetune.ipynb As you can see at the end of the notebook, the error message is about the masks. Hopefully this helps to identify where I'm doing something awfully wrong! |
FYI, when taking |
Okay - I think the cause of the error in the dataloader is that the masking function assumes nothing is masked. However, when using This requires a re-write of the masking function, but shouldn't be too complicated. I can give this a shot in the next few days. |
Not exactly sure what you mean. Indeed, I use |
@kvantricht , I opened a PR into the branch you shared (kvantricht#1) which shows those changes. I've added a description there of exactly the changes I made |
Amazing Gabi, it runs here now as well with batch size 4096. I will keep you posted on the outcome! |
When using construct_single_presto_input, the code conveniently handles the normalization of the inputs and construction of the mask. If certain inputs (bands) are missing, the respective mask values are automatically set to 1. However, there seems to be no way to deal with certain missing timesteps in the inputs. Imagine a monthly compositing of Sentinel-2 resulting in no valid observations for some month. We could deal with it by linearly interpolating the missing values, but it seems Presto should actually be able to natively deal with missing timesteps by setting the respective mask value to 1.
At the moment, the only way to do it is by keeping track of missing value positions in the original inputs and, after the call to construct_single_presto_input, setting the mask of these positions to 1. Would there be a more convenient way of doing this? I am thinking of certain no-data values that this method could treat as missing, setting the mask accordingly.
Specific side note on automatic computation of NDVI: we were testing with NaN inputs for S2 to see how the code behaves. Interestingly, this line actually makes up a valid NDVI value of 0 in x when the inputs are invalid.
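The manual workaround described above could look roughly like the sketch below. It assumes that `x` and `mask` are arrays of the same shape (timesteps, channels) and that missing positions are marked in `x` by NaN or by a chosen no-data value; the helper name `mask_missing_values` and the `nodata` parameter are hypothetical and not part of Presto.

```python
import numpy as np


def mask_missing_values(x, mask, nodata=None):
    """Set the mask to 1 wherever `x` is NaN or equals `nodata`.

    Returns copies (new_x, new_mask). Missing entries in new_x are
    replaced by 0.0 so that no NaN reaches later computations; the
    mask is what records that those values are not real observations.
    """
    x = np.asarray(x, dtype=float)
    missing = np.isnan(x)
    if nodata is not None:
        missing |= x == nodata
    new_mask = np.where(missing, 1, mask)
    new_x = np.where(missing, 0.0, x)
    return new_x, new_mask


# Three timesteps, two channels; one NaN and one no-data value.
x = np.array([[0.1, 0.2], [np.nan, 0.4], [0.5, -9999.0]])
mask = np.zeros(x.shape, dtype=int)
new_x, new_mask = mask_missing_values(x, mask, nodata=-9999.0)
```

Positions that were already masked stay masked, and only the missing positions are added. Note that derived channels such as NDVI would need the same treatment: if a value is computed from invalid inputs, its mask entry should be set to 1 as well rather than trusting the filled-in number.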