This repository was archived by the owner on Jun 2, 2023. It is now read-only.

Commit

resolve merge issues
galengorski committed Aug 5, 2022
2 parents 9c0741e + ba69d4c commit 948ea80
Showing 5 changed files with 39 additions and 28 deletions.
2 changes: 2 additions & 0 deletions 03b_model/model_config.yaml
@@ -11,6 +11,7 @@ inputs: ['discharge_01463500',
'wind_direction_delsjmet',
'wind_speed_delsjmet']#,'wind_speed_direction_delsjmet']


#the sinks to be analyzed, this will likely be the salt front location
target: 'saltfront7_weekly'
replicates: 5
@@ -45,3 +46,4 @@ dropout: 0.1
n_epochs: 500
learn_rate: 0.001


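The keys above (inputs, target, replicates, dropout, n_epochs, learn_rate) are presumably read by the modeling code. A minimal sketch of loading this file with PyYAML, assuming it is available in the environment; the variable names are illustrative only, not taken from the repository:

import yaml

# Path taken from the file header above; adjust to your checkout.
with open('03b_model/model_config.yaml') as f:
    config = yaml.safe_load(f)

inputs = config['inputs']          # list of driver series, e.g. 'discharge_01463500'
target = config['target']          # 'saltfront7_weekly' in this commit
n_epochs = config['n_epochs']      # 500 in this commit
learn_rate = config['learn_rate']  # 0.001 in this commit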
49 changes: 24 additions & 25 deletions 03b_model/src/LSTMDA_torch.py
@@ -99,35 +99,34 @@ def mse_masked(y_true, y_pred):
mse_loss = sum_squared_errors / num_y_true
return mse_loss

def extreme_loss(y_true, y_pred):
y_true_or_err = torch.where(
torch.isnan(y_true), torch.Tensor([-9999]) , y_true
)
#-0.0138 is approx river mile 70 scaled by mean and sd of the training set
low_zero_or_error = torch.where(
(y_true_or_err >= -0.0138) | (y_true_or_err < -999), torch.zeros_like(y_true_or_err), y_pred - y_true_or_err
)
num_low_y_true = torch.count_nonzero(
low_zero_or_error
)
high_zero_or_error = torch.where(
(y_true_or_err < -0.0138) , torch.zeros_like(y_true_or_err), y_pred - y_true_or_err
)
num_high_y_true = torch.count_nonzero(
high_zero_or_error
)
# def extreme_loss(y_true, y_pred):
# y_true_or_err = torch.where(
# torch.isnan(y_true), torch.Tensor([-9999]) , y_true
# )
# #-0.0138 is approx river mile 70 scaled by mean and sd of the training set
# low_zero_or_error = torch.where(
# (y_true_or_err >= -0.0138) | (y_true_or_err < -999), torch.zeros_like(y_true_or_err), y_pred - y_true_or_err
# )
# num_low_y_true = torch.count_nonzero(
# low_zero_or_error
# )
# high_zero_or_error = torch.where(
# (y_true_or_err < -0.0138) , torch.zeros_like(y_true_or_err), y_pred - y_true_or_err
# )
# num_high_y_true = torch.count_nonzero(
# high_zero_or_error
# )

sum_squared_errors_low = torch.sum(torch.square(low_zero_or_error))
loss_low = torch.sqrt(sum_squared_errors_low / num_low_y_true)

#cube the errors where river mile is high
sum_cubed_errors_high = torch.sum(torch.pow(high_zero_or_error,4))
loss_high = (sum_cubed_errors_high / num_high_y_true)
# sum_squared_errors_low = torch.sum(torch.square(low_zero_or_error))
# loss_low = torch.sqrt(sum_squared_errors_low / num_low_y_true)

loss_hi_low = loss_low.add(loss_high)
# #cube the errors where river mile is high
# sum_cubed_errors_high = torch.sum(torch.pow(high_zero_or_error,4))
# loss_high = (sum_cubed_errors_high / num_high_y_true)

return loss_hi_low
# loss_hi_low = loss_low.add(loss_high)

# return loss_hi_low


# def rmse_weighted(y_true, y_pred): # weighted by covariance matrix from DA; weights are concatonated onto y_true and need to separate out within function
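For reference, the extreme_loss function commented out above penalizes errors differently on either side of a threshold of -0.0138 (approximately river mile 70 after scaling by the training-set mean and standard deviation): RMSE below the threshold and fourth-power errors above it, with a -9999 sentinel standing in for missing (NaN) targets. A self-contained sketch of that logic reconstructed from the removed lines; the helper name is illustrative only:

import torch

def extreme_loss_sketch(y_true, y_pred):
    # Replace NaN targets with a -9999 sentinel so they can be masked out below.
    y_true_or_err = torch.where(torch.isnan(y_true), torch.tensor(-9999.0), y_true)

    # Errors where the scaled salt front is below the threshold;
    # sentinel values and high river miles contribute zero.
    low_err = torch.where(
        (y_true_or_err >= -0.0138) | (y_true_or_err < -999),
        torch.zeros_like(y_true_or_err),
        y_pred - y_true_or_err,
    )
    # Errors where the scaled salt front is at or above the threshold;
    # sentinel values and low river miles contribute zero.
    high_err = torch.where(
        y_true_or_err < -0.0138,
        torch.zeros_like(y_true_or_err),
        y_pred - y_true_or_err,
    )

    # RMSE over the low-river-mile errors.
    loss_low = torch.sqrt(torch.sum(low_err ** 2) / torch.count_nonzero(low_err))
    # Fourth-power mean over the high-river-mile errors (the original comment says
    # "cube", but the code raises the errors to the fourth power).
    loss_high = torch.sum(high_err ** 4) / torch.count_nonzero(high_err)
    return loss_low + loss_high

Here y_true and y_pred would be same-shaped float tensors, with NaN marking missing salt-front observations in y_true.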
9 changes: 9 additions & 0 deletions 03b_model/src/run_model.py
@@ -149,8 +149,13 @@ def select_inputs_targets(inputs, target, train_start_date, test_end_date, out_d
# else:
# mask = target_df_c['saltfront_daily'] < 54
# target_df_c.loc[mask,'saltfront_daily'] = np.nan
<<<<<<< HEAD
mask = target_df_c['saltfront7_weekly'] < 54
target_df_c.loc[mask,'saltfront7_weekly'] = np.nan
=======
mask = target_df_c['saltfront_daily'] < 54
target_df_c.loc[mask,'saltfront_daily'] = np.nan
>>>>>>> ba69d4c077cc19a40e3c4275ae19fec7a9fc1163

inputs_xarray = inputs_df.to_xarray()
target_xarray = target_df_c.to_xarray()
@@ -359,7 +364,10 @@ def write_model_params(out_dir, run_id, inputs, n_epochs,
f = open(os.path.join(dir,"model_param_output.txt"),"w+")
f.write("Date: %s\r\n" % date.today().strftime("%b-%d-%Y"))
f.write("Feature List: %s\r\n" % inputs_log)
<<<<<<< HEAD
f.write("Target: %s\r\n" % target)
=======
>>>>>>> ba69d4c077cc19a40e3c4275ae19fec7a9fc1163
f.write("Include antecedant variable: %s\r\n" % inc_ante)
f.write("Epochs: %d\r\n" % n_epochs)
f.write("Learning rate: %f\r\n" % learn_rate)
@@ -617,6 +625,7 @@ def run_replicates(n_reps, prepped_model_io_data_file):
test_start_date, test_end_date)

plot_save_predictions(predictions, out_dir, run_id)



def test_hyperparameters():
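Both sides of the conflict shown above apply the same masking step and differ only in the target column name ('saltfront7_weekly' on HEAD versus 'saltfront_daily' on the incoming branch): salt-front observations below river mile 54 are set to NaN. A minimal, self-contained illustration of that step; the DataFrame values are made up:

import numpy as np
import pandas as pd

# Hypothetical example: salt-front positions in river miles.
target_df_c = pd.DataFrame({'saltfront7_weekly': [48.0, 62.5, 71.0, 53.9]})

# Observations below river mile 54 are treated as missing.
mask = target_df_c['saltfront7_weekly'] < 54
target_df_c.loc[mask, 'saltfront7_weekly'] = np.nan

print(target_df_c)
#    saltfront7_weekly
# 0                NaN
# 1               62.5
# 2               71.0
# 3                NaN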
3 changes: 0 additions & 3 deletions Snakefile_fetch_munge
@@ -89,6 +89,3 @@ rule munge_noaa_nerrs:
munge = importlib.import_module('02_munge.src.munge_noaa_nerrs')
munge.munge_single_site_data('delsjmet')




4 changes: 4 additions & 0 deletions environment.yaml
@@ -39,3 +39,7 @@ dependencies:
- pypi_install
- sciencebasepy
- scikit-learn



