Skip to content

Commit 5915aa8

Browse files
committed
Add support for overloading derived field functions
1 parent fa1297e commit 5915aa8

File tree

8 files changed

+157
-30
lines changed

8 files changed

+157
-30
lines changed

book/src/field_functions.md

+19
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ are:
4646
| [`Protocol::BIT_XOR_ASSIGN`] | `#[rune(bit_xor_assign)]` | The `^=` operation. |
4747
| [`Protocol::SHL_ASSIGN`] | `#[rune(shl_assign)]` | The `<<=` operation. |
4848
| [`Protocol::SHR_ASSIGN`] | `#[rune(shr_assign)]` | The `>>=` operation. |
49+
| [`Protocol::REM_ASSIGN`] | `#[rune(rem_assign)]` | The `%=` operation. |
4950

5051
The manual way to register these functions is to use the new `Module::field_fn`
5152
function. This clearly showcases that there's no relationship between the field
@@ -76,6 +77,23 @@ pub fn main(external) {
7677
}
7778
```
7879

80+
## Custom field function
81+
82+
Using the `Any` derive, you can specify a custom field function by using an
83+
argument to the corresponding attribute pointing to the function to use instead.
84+
85+
The following uses an implementation of `add_assign` which performs checked
86+
addition:
87+
88+
```rust,noplaypen
89+
{{#include ../../crates/rune/examples/checked_add_assign.rs}}
90+
```
91+
92+
```text
93+
$> cargo run --example checked_add_assign
94+
Error: numerical overflow (at inst 2)
95+
```
96+
7997
[`Protocol::GET`]: https://docs.rs/runestick/0/runestick/struct.Protocol.html#associatedconstant.GET
8098
[`Protocol::SET`]: https://docs.rs/runestick/0/runestick/struct.Protocol.html#associatedconstant.SET
8199
[`Protocol::ADD_ASSIGN`]: https://docs.rs/runestick/0/runestick/struct.Protocol.html#associatedconstant.ADD_ASSIGN
@@ -87,3 +105,4 @@ pub fn main(external) {
87105
[`Protocol::BIT_XOR_ASSIGN`]: https://docs.rs/runestick/0/runestick/struct.Protocol.html#associatedconstant.BIT_XOR_ASSIGN
88106
[`Protocol::SHL_ASSIGN`]: https://docs.rs/runestick/0/runestick/struct.Protocol.html#associatedconstant.SHL_ASSIGN
89107
[`Protocol::SHR_ASSIGN`]: https://docs.rs/runestick/0/runestick/struct.Protocol.html#associatedconstant.SHR_ASSIGN
108+
[`Protocol::REM_ASSIGN`]: https://docs.rs/runestick/0/runestick/struct.Protocol.html#associatedconstant.REM_ASSIGN
+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
use runestick::{Any, Context, Module, VmError, VmErrorKind};
2+
use std::sync::Arc;
3+
4+
#[derive(Any)]
5+
struct External {
6+
#[rune(add_assign = "External::value_add_assign")]
7+
value: i64,
8+
}
9+
10+
impl External {
11+
fn value_add_assign(&mut self, other: i64) -> Result<(), VmError> {
12+
self.value = self
13+
.value
14+
.checked_add(other)
15+
.ok_or_else(|| VmErrorKind::Overflow)?;
16+
17+
Ok(())
18+
}
19+
}
20+
21+
fn main() -> runestick::Result<()> {
22+
let mut module = Module::default();
23+
module.ty::<External>()?;
24+
25+
let mut context = Context::default();
26+
context.install(&module)?;
27+
let context = Arc::new(context);
28+
29+
let external = External {
30+
value: i64::max_value(),
31+
};
32+
33+
let result = rune::testing::run::<_, _, ()>(
34+
&context,
35+
&["main"],
36+
(external,),
37+
"pub fn main(external) { external.value += 1; }",
38+
);
39+
40+
let error = result.expect_err("expected error");
41+
let error = error.expect_vm_error("expected vm error");
42+
println!("Error: {}", error);
43+
Ok(())
44+
}

crates/rune/src/testing.rs

+10
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,16 @@ pub enum RunError {
2929
VmError(#[source] VmError),
3030
}
3131

32+
impl RunError {
33+
/// Unpack into a vm error or panic with the given message.
34+
pub fn expect_vm_error(self, msg: &str) -> VmError {
35+
match self {
36+
Self::VmError(error) => error,
37+
_ => panic!("{}", msg),
38+
}
39+
}
40+
}
41+
3242
/// Compile the given source into a unit and collection of warnings.
3343
pub fn compile_source(
3444
context: &runestick::Context,

crates/rune/tests/test_all/external_ops.rs

+13-3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ fn test_external_ops() {
1313
field: i64,
1414
#[rune($derived)]
1515
derived: i64,
16+
#[rune($derived = "External::custom")]
17+
custom: i64,
1618
}
1719

1820
impl External {
@@ -23,6 +25,10 @@ fn test_external_ops() {
2325
fn field(&mut self, value: i64) {
2426
self.field $($op)* value;
2527
}
28+
29+
fn custom(&mut self, value: i64) {
30+
self.custom $($op)* value;
31+
}
2632
}
2733

2834
let mut module = Module::empty();
@@ -49,6 +55,7 @@ fn test_external_ops() {
4955
number {op} {arg};
5056
number.field {op} {arg};
5157
number.derived {op} {arg};
58+
number.custom {op} {arg};
5259
}}
5360
"#, op = stringify!($($op)*), arg = stringify!($arg)),
5461
));
@@ -72,12 +79,14 @@ fn test_external_ops() {
7279
foo.value = $initial;
7380
foo.field = $initial;
7481
foo.derived = $initial;
82+
foo.custom = $initial;
7583

7684
let output = vm.clone().call(&["type"], (&mut foo,)).unwrap();
7785

78-
assert_eq!(foo.value, $expected);
79-
assert_eq!(foo.field, $expected);
80-
assert_eq!(foo.derived, $expected);
86+
assert_eq!(foo.value, $expected, "{} != {} (value)", foo.value, $expected);
87+
assert_eq!(foo.field, $expected, "{} != {} (field)", foo.value, $expected);
88+
assert_eq!(foo.derived, $expected, "{} != {} (derived)", foo.value, $expected);
89+
assert_eq!(foo.custom, $expected, "{} != {} (custom)", foo.value, $expected);
8190
assert!(matches!(output, Value::Unit));
8291
}
8392
}};
@@ -92,4 +101,5 @@ fn test_external_ops() {
92101
test_case!([^=], BIT_XOR_ASSIGN, bit_xor_assign, 0b1001, 0b0011, 0b1010);
93102
test_case!([<<=], SHL_ASSIGN, shl_assign, 0b1001, 0b0001, 0b10010);
94103
test_case!([>>=], SHR_ASSIGN, shr_assign, 0b1001, 0b0001, 0b100);
104+
test_case!([%=], REM_ASSIGN, rem_assign, 25, 10, 5);
95105
}

crates/runestick-macros/src/context.rs

+67-17
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use syn::NestedMeta::*;
1111
struct Generate<'a> {
1212
context: &'a Context,
1313
attrs: &'a FieldAttrs,
14+
protocol: &'a FieldProtocol,
1415
ident: &'a syn::Ident,
1516
field: &'a syn::Field,
1617
field_ident: &'a syn::Ident,
@@ -20,6 +21,7 @@ struct Generate<'a> {
2021

2122
pub(crate) struct FieldProtocol {
2223
generate: fn(Generate<'_>) -> TokenStream,
24+
custom: Option<syn::Path>,
2325
}
2426

2527
/// Parsed field attributes.
@@ -148,10 +150,16 @@ impl Context {
148150

149151
let protocol = g.context.protocol($proto);
150152

151-
quote_spanned! { g.field.span() =>
152-
module.field_fn(#protocol, #name, |s: &mut #ident, value: #ty| {
153-
s.#field_ident $op value;
154-
})?;
153+
if let Some(custom) = &g.protocol.custom {
154+
quote_spanned! { g.field.span() =>
155+
module.field_fn(#protocol, #name, #custom)?;
156+
}
157+
} else {
158+
quote_spanned! { g.field.span() =>
159+
module.field_fn(#protocol, #name, |s: &mut #ident, value: #ty| {
160+
s.#field_ident $op value;
161+
})?;
162+
}
155163
}
156164
}
157165
};
@@ -162,8 +170,12 @@ impl Context {
162170
for attr in attrs {
163171
for meta in self.get_rune_meta_items(attr)? {
164172
match meta {
165-
Meta(Path(path)) if path == GET => {
173+
Meta(Path(path)) if path == COPY => {
174+
output.copy = true;
175+
}
176+
Meta(meta) if meta.path() == GET => {
166177
output.protocols.push(FieldProtocol {
178+
custom: self.parse_field_custom(meta)?,
167179
generate: |g| {
168180
let Generate {
169181
ident,
@@ -186,8 +198,9 @@ impl Context {
186198
},
187199
});
188200
}
189-
Meta(Path(path)) if path == SET => {
201+
Meta(meta) if meta.path() == SET => {
190202
output.protocols.push(FieldProtocol {
203+
custom: self.parse_field_custom(meta)?,
191204
generate: |g| {
192205
let Generate {
193206
ident,
@@ -207,53 +220,65 @@ impl Context {
207220
},
208221
});
209222
}
210-
Meta(Path(path)) if path == ADD_ASSIGN => {
223+
Meta(meta) if meta.path() == ADD_ASSIGN => {
211224
output.protocols.push(FieldProtocol {
225+
custom: self.parse_field_custom(meta)?,
212226
generate: generate_op!(PROTOCOL_ADD_ASSIGN, +=),
213227
});
214228
}
215-
Meta(Path(path)) if path == SUB_ASSIGN => {
229+
Meta(meta) if meta.path() == SUB_ASSIGN => {
216230
output.protocols.push(FieldProtocol {
231+
custom: self.parse_field_custom(meta)?,
217232
generate: generate_op!(PROTOCOL_SUB_ASSIGN, -=),
218233
});
219234
}
220-
Meta(Path(path)) if path == DIV_ASSIGN => {
235+
Meta(meta) if meta.path() == DIV_ASSIGN => {
221236
output.protocols.push(FieldProtocol {
237+
custom: self.parse_field_custom(meta)?,
222238
generate: generate_op!(PROTOCOL_DIV_ASSIGN, /=),
223239
});
224240
}
225-
Meta(Path(path)) if path == MUL_ASSIGN => {
241+
Meta(meta) if meta.path() == MUL_ASSIGN => {
226242
output.protocols.push(FieldProtocol {
243+
custom: self.parse_field_custom(meta)?,
227244
generate: generate_op!(PROTOCOL_MUL_ASSIGN, *=),
228245
});
229246
}
230-
Meta(Path(path)) if path == BIT_AND_ASSIGN => {
247+
Meta(meta) if meta.path() == BIT_AND_ASSIGN => {
231248
output.protocols.push(FieldProtocol {
249+
custom: self.parse_field_custom(meta)?,
232250
generate: generate_op!(PROTOCOL_BIT_AND_ASSIGN, &=),
233251
});
234252
}
235-
Meta(Path(path)) if path == BIT_OR_ASSIGN => {
253+
Meta(meta) if meta.path() == BIT_OR_ASSIGN => {
236254
output.protocols.push(FieldProtocol {
255+
custom: self.parse_field_custom(meta)?,
237256
generate: generate_op!(PROTOCOL_BIT_OR_ASSIGN, |=),
238257
});
239258
}
240-
Meta(Path(path)) if path == BIT_XOR_ASSIGN => {
259+
Meta(meta) if meta.path() == BIT_XOR_ASSIGN => {
241260
output.protocols.push(FieldProtocol {
261+
custom: self.parse_field_custom(meta)?,
242262
generate: generate_op!(PROTOCOL_BIT_XOR_ASSIGN, ^=),
243263
});
244264
}
245-
Meta(Path(path)) if path == SHL_ASSIGN => {
265+
Meta(meta) if meta.path() == SHL_ASSIGN => {
246266
output.protocols.push(FieldProtocol {
267+
custom: self.parse_field_custom(meta)?,
247268
generate: generate_op!(PROTOCOL_SHL_ASSIGN, <<=),
248269
});
249270
}
250-
Meta(Path(path)) if path == SHR_ASSIGN => {
271+
Meta(meta) if meta.path() == SHR_ASSIGN => {
251272
output.protocols.push(FieldProtocol {
273+
custom: self.parse_field_custom(meta)?,
252274
generate: generate_op!(PROTOCOL_SHR_ASSIGN, >>=),
253275
});
254276
}
255-
Meta(Path(path)) if path == COPY => {
256-
output.copy = true;
277+
Meta(meta) if meta.path() == REM_ASSIGN => {
278+
output.protocols.push(FieldProtocol {
279+
custom: self.parse_field_custom(meta)?,
280+
generate: generate_op!(PROTOCOL_REM_ASSIGN, %=),
281+
});
257282
}
258283
_ => {
259284
self.errors
@@ -268,6 +293,30 @@ impl Context {
268293
Some(output)
269294
}
270295

296+
/// Parse path to custom field function.
297+
fn parse_field_custom(&mut self, meta: syn::Meta) -> Option<Option<syn::Path>> {
298+
let s = match meta {
299+
Path(..) => return Some(None),
300+
NameValue(syn::MetaNameValue {
301+
lit: syn::Lit::Str(s),
302+
..
303+
}) => s,
304+
_ => {
305+
self.errors
306+
.push(syn::Error::new(meta.span(), "unsupported meta"));
307+
return None;
308+
}
309+
};
310+
311+
match s.parse_with(syn::Path::parse_mod_style) {
312+
Ok(path) => Some(Some(path)),
313+
Err(error) => {
314+
self.errors.push(error);
315+
None
316+
}
317+
}
318+
}
319+
271320
/// Parse field attributes.
272321
pub(crate) fn parse_derive_attrs(&mut self, attrs: &[syn::Attribute]) -> Option<DeriveAttrs> {
273322
let mut output = DeriveAttrs::default();
@@ -328,6 +377,7 @@ impl Context {
328377
for protocol in &attrs.protocols {
329378
installers.push((protocol.generate)(Generate {
330379
context: self,
380+
protocol,
331381
attrs: &attrs,
332382
ident,
333383
field,

crates/runestick-macros/src/internals.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ pub const NAME: Symbol = Symbol("name");
1111

1212
pub const GET: Symbol = Symbol("get");
1313
pub const SET: Symbol = Symbol("set");
14+
pub const COPY: Symbol = Symbol("copy");
1415

1516
pub const ADD_ASSIGN: Symbol = Symbol("add_assign");
1617
pub const SUB_ASSIGN: Symbol = Symbol("sub_assign");
@@ -21,8 +22,7 @@ pub const BIT_OR_ASSIGN: Symbol = Symbol("bit_or_assign");
2122
pub const BIT_XOR_ASSIGN: Symbol = Symbol("bit_xor_assign");
2223
pub const SHL_ASSIGN: Symbol = Symbol("shl_assign");
2324
pub const SHR_ASSIGN: Symbol = Symbol("shr_assign");
24-
25-
pub const COPY: Symbol = Symbol("copy");
25+
pub const REM_ASSIGN: Symbol = Symbol("rem_assign");
2626

2727
pub const PROTOCOL_GET: Symbol = Symbol("GET");
2828
pub const PROTOCOL_SET: Symbol = Symbol("SET");
@@ -35,6 +35,7 @@ pub const PROTOCOL_BIT_OR_ASSIGN: Symbol = Symbol("BIT_OR_ASSIGN");
3535
pub const PROTOCOL_BIT_XOR_ASSIGN: Symbol = Symbol("BIT_XOR_ASSIGN");
3636
pub const PROTOCOL_SHL_ASSIGN: Symbol = Symbol("SHL_ASSIGN");
3737
pub const PROTOCOL_SHR_ASSIGN: Symbol = Symbol("SHR_ASSIGN");
38+
pub const PROTOCOL_REM_ASSIGN: Symbol = Symbol("REM_ASSIGN");
3839

3940
impl PartialEq<Symbol> for syn::Ident {
4041
fn eq(&self, word: &Symbol) -> bool {

crates/runestick/src/module.rs

+1-3
Original file line numberDiff line numberDiff line change
@@ -967,9 +967,7 @@ macro_rules! impl_register {
967967
(@return $stack:ident, $ret:ident, $ty:ty) => {
968968
let $ret = match $ret.to_value() {
969969
Ok($ret) => $ret,
970-
Err(e) => return Err(VmError::from(VmErrorKind::BadReturn {
971-
error: e.unpack_critical()?,
972-
})),
970+
Err(e) => return Err(VmError::from(e.unpack_critical()?)),
973971
};
974972

975973
$stack.push($ret);

crates/runestick/src/vm_error.rs

-5
Original file line numberDiff line numberDiff line change
@@ -215,11 +215,6 @@ pub enum VmErrorKind {
215215
error: VmError,
216216
arg: usize,
217217
},
218-
#[error("bad return value: {error}")]
219-
BadReturn {
220-
#[source]
221-
error: VmError,
222-
},
223218
#[error("the index set operation `{target}[{index}] = {value}` is not supported")]
224219
UnsupportedIndexSet {
225220
target: TypeInfo,

0 commit comments

Comments
 (0)