Skip to content

Commit

Permalink
Implement DurationRound for NaiveDateTime
Browse files Browse the repository at this point in the history
This is off the back of [this
comment](chronotope#445 (comment)).
  • Loading branch information
robyoung authored and pickfire committed Jul 5, 2022
1 parent e6f4385 commit c3eec31
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 43 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Versions with only mechanical changes will be omitted from the following list.
* Add more formatting documentation and examples.
* Add support for microseconds timestamps serde serialization/deserialization (#304)
* Fix `DurationRound` is not TZ aware (#495)
* Implement `DurationRound` for `NaiveDateTime`

## 0.4.19

Expand Down
196 changes: 153 additions & 43 deletions src/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use core::fmt;
use core::marker::Sized;
use core::ops::{Add, Sub};
use datetime::DateTime;
use naive::NaiveDateTime;
use oldtime::Duration;
#[cfg(any(feature = "std", test))]
use std;
Expand Down Expand Up @@ -150,56 +151,86 @@ impl<Tz: TimeZone> DurationRound for DateTime<Tz> {
type Err = RoundingError;

fn duration_round(self, duration: Duration) -> Result<Self, Self::Err> {
if let Some(span) = duration.num_nanoseconds() {
let naive = self.naive_local();
duration_round(self.naive_local(), self, duration)
}

if naive.timestamp().abs() > MAX_SECONDS_TIMESTAMP_FOR_NANOS {
return Err(RoundingError::TimestampExceedsLimit);
}
let stamp = naive.timestamp_nanos();
if span > stamp.abs() {
return Err(RoundingError::DurationExceedsTimestamp);
}
let delta_down = stamp % span;
if delta_down == 0 {
Ok(self)
} else {
let (delta_up, delta_down) = if delta_down < 0 {
(delta_down.abs(), span - delta_down.abs())
} else {
(span - delta_down, delta_down)
};
if delta_up <= delta_down {
Ok(self + Duration::nanoseconds(delta_up))
} else {
Ok(self - Duration::nanoseconds(delta_down))
}
}
} else {
Err(RoundingError::DurationExceedsLimit)
}
fn duration_trunc(self, duration: Duration) -> Result<Self, Self::Err> {
duration_trunc(self.naive_local(), self, duration)
}
}

impl DurationRound for NaiveDateTime {
type Err = RoundingError;

fn duration_round(self, duration: Duration) -> Result<Self, Self::Err> {
duration_round(self, self, duration)
}

fn duration_trunc(self, duration: Duration) -> Result<Self, Self::Err> {
if let Some(span) = duration.num_nanoseconds() {
let naive = self.naive_local();
duration_trunc(self, self, duration)
}
}

if naive.timestamp().abs() > MAX_SECONDS_TIMESTAMP_FOR_NANOS {
return Err(RoundingError::TimestampExceedsLimit);
}
let stamp = naive.timestamp_nanos();
if span > stamp.abs() {
return Err(RoundingError::DurationExceedsTimestamp);
}
let delta_down = stamp % span;
match delta_down.cmp(&0) {
Ordering::Equal => Ok(self),
Ordering::Greater => Ok(self - Duration::nanoseconds(delta_down)),
Ordering::Less => Ok(self - Duration::nanoseconds(span - delta_down.abs())),
}
fn duration_round<T>(
naive: NaiveDateTime,
original: T,
duration: Duration,
) -> Result<T, RoundingError>
where
T: Timelike + Add<Duration, Output = T> + Sub<Duration, Output = T>,
{
if let Some(span) = duration.num_nanoseconds() {
if naive.timestamp().abs() > MAX_SECONDS_TIMESTAMP_FOR_NANOS {
return Err(RoundingError::TimestampExceedsLimit);
}
let stamp = naive.timestamp_nanos();
if span > stamp.abs() {
return Err(RoundingError::DurationExceedsTimestamp);
}
let delta_down = stamp % span;
if delta_down == 0 {
Ok(original)
} else {
Err(RoundingError::DurationExceedsLimit)
let (delta_up, delta_down) = if delta_down < 0 {
(delta_down.abs(), span - delta_down.abs())
} else {
(span - delta_down, delta_down)
};
if delta_up <= delta_down {
Ok(original + Duration::nanoseconds(delta_up))
} else {
Ok(original - Duration::nanoseconds(delta_down))
}
}
} else {
Err(RoundingError::DurationExceedsLimit)
}
}

fn duration_trunc<T>(
naive: NaiveDateTime,
original: T,
duration: Duration,
) -> Result<T, RoundingError>
where
T: Timelike + Add<Duration, Output = T> + Sub<Duration, Output = T>,
{
if let Some(span) = duration.num_nanoseconds() {
if naive.timestamp().abs() > MAX_SECONDS_TIMESTAMP_FOR_NANOS {
return Err(RoundingError::TimestampExceedsLimit);
}
let stamp = naive.timestamp_nanos();
if span > stamp.abs() {
return Err(RoundingError::DurationExceedsTimestamp);
}
let delta_down = stamp % span;
match delta_down.cmp(&0) {
Ordering::Equal => Ok(original),
Ordering::Greater => Ok(original - Duration::nanoseconds(delta_down)),
Ordering::Less => Ok(original - Duration::nanoseconds(span - delta_down.abs())),
}
} else {
Err(RoundingError::DurationExceedsLimit)
}
}

Expand Down Expand Up @@ -423,6 +454,46 @@ mod tests {
);
}

#[test]
fn test_duration_round_naive() {
let dt = Utc.ymd(2016, 12, 31).and_hms_nano(23, 59, 59, 175_500_000).naive_utc();

assert_eq!(
dt.duration_round(Duration::milliseconds(10)).unwrap().to_string(),
"2016-12-31 23:59:59.180"
);

// round up
let dt = Utc.ymd(2012, 12, 12).and_hms_milli(18, 22, 30, 0).naive_utc();
assert_eq!(
dt.duration_round(Duration::minutes(5)).unwrap().to_string(),
"2012-12-12 18:25:00"
);
// round down
let dt = Utc.ymd(2012, 12, 12).and_hms_milli(18, 22, 29, 999).naive_utc();
assert_eq!(
dt.duration_round(Duration::minutes(5)).unwrap().to_string(),
"2012-12-12 18:20:00"
);

assert_eq!(
dt.duration_round(Duration::minutes(10)).unwrap().to_string(),
"2012-12-12 18:20:00"
);
assert_eq!(
dt.duration_round(Duration::minutes(30)).unwrap().to_string(),
"2012-12-12 18:30:00"
);
assert_eq!(
dt.duration_round(Duration::hours(1)).unwrap().to_string(),
"2012-12-12 18:00:00"
);
assert_eq!(
dt.duration_round(Duration::days(1)).unwrap().to_string(),
"2012-12-13 00:00:00"
);
}

#[test]
fn test_duration_round_pre_epoch() {
let dt = Utc.ymd(1969, 12, 12).and_hms(12, 12, 12);
Expand Down Expand Up @@ -493,6 +564,45 @@ mod tests {
);
}

#[test]
fn test_duration_trunc_naive() {
let dt = Utc.ymd(2016, 12, 31).and_hms_nano(23, 59, 59, 1_75_500_000).naive_utc();

assert_eq!(
dt.duration_trunc(Duration::milliseconds(10)).unwrap().to_string(),
"2016-12-31 23:59:59.170"
);

// would round up
let dt = Utc.ymd(2012, 12, 12).and_hms_milli(18, 22, 30, 0).naive_utc();
assert_eq!(
dt.duration_trunc(Duration::minutes(5)).unwrap().to_string(),
"2012-12-12 18:20:00"
);
// would round down
let dt = Utc.ymd(2012, 12, 12).and_hms_milli(18, 22, 29, 999).naive_utc();
assert_eq!(
dt.duration_trunc(Duration::minutes(5)).unwrap().to_string(),
"2012-12-12 18:20:00"
);
assert_eq!(
dt.duration_trunc(Duration::minutes(10)).unwrap().to_string(),
"2012-12-12 18:20:00"
);
assert_eq!(
dt.duration_trunc(Duration::minutes(30)).unwrap().to_string(),
"2012-12-12 18:00:00"
);
assert_eq!(
dt.duration_trunc(Duration::hours(1)).unwrap().to_string(),
"2012-12-12 18:00:00"
);
assert_eq!(
dt.duration_trunc(Duration::days(1)).unwrap().to_string(),
"2012-12-12 00:00:00"
);
}

#[test]
fn test_duration_trunc_pre_epoch() {
let dt = Utc.ymd(1969, 12, 12).and_hms(12, 12, 12);
Expand Down

0 comments on commit c3eec31

Please sign in to comment.