1
+ use ruff_text_size:: { TextRange , TextSize } ;
2
+ use rustpython_parser:: ast:: Ranged ;
3
+
4
+ use ruff_formatter:: { format_args, write, Argument , Arguments } ;
5
+
1
6
use crate :: context:: NodeLevel ;
2
7
use crate :: prelude:: * ;
3
- use crate :: trivia:: { first_non_trivia_token, lines_after, skip_trailing_trivia, Token , TokenKind } ;
4
- use ruff_formatter:: { format_args, write, Argument , Arguments } ;
5
- use ruff_text_size:: TextSize ;
6
- use rustpython_parser:: ast:: Ranged ;
8
+ use crate :: trivia:: { lines_after, skip_trailing_trivia, SimpleTokenizer , Token , TokenKind } ;
9
+ use crate :: MagicTrailingComma ;
7
10
8
11
/// Adds parentheses and indents `content` if it doesn't fit on a line.
9
12
pub ( crate ) fn parenthesize_if_expands < ' ast , T > ( content : & T ) -> ParenthesizeIfExpands < ' _ , ' ast >
@@ -53,16 +56,22 @@ pub(crate) trait PyFormatterExtensions<'ast, 'buf> {
53
56
/// A builder that separates each element by a `,` and a [`soft_line_break_or_space`].
54
57
/// It emits a trailing `,` that is only shown if the enclosing group expands. It forces the enclosing
55
58
/// group to expand if the last item has a trailing `comma` and the magical comma option is enabled.
56
- fn join_comma_separated < ' fmt > ( & ' fmt mut self ) -> JoinCommaSeparatedBuilder < ' fmt , ' ast , ' buf > ;
59
+ fn join_comma_separated < ' fmt > (
60
+ & ' fmt mut self ,
61
+ sequence_end : TextSize ,
62
+ ) -> JoinCommaSeparatedBuilder < ' fmt , ' ast , ' buf > ;
57
63
}
58
64
59
65
impl < ' buf , ' ast > PyFormatterExtensions < ' ast , ' buf > for PyFormatter < ' ast , ' buf > {
60
66
fn join_nodes < ' fmt > ( & ' fmt mut self , level : NodeLevel ) -> JoinNodesBuilder < ' fmt , ' ast , ' buf > {
61
67
JoinNodesBuilder :: new ( self , level)
62
68
}
63
69
64
- fn join_comma_separated < ' fmt > ( & ' fmt mut self ) -> JoinCommaSeparatedBuilder < ' fmt , ' ast , ' buf > {
65
- JoinCommaSeparatedBuilder :: new ( self )
70
+ fn join_comma_separated < ' fmt > (
71
+ & ' fmt mut self ,
72
+ sequence_end : TextSize ,
73
+ ) -> JoinCommaSeparatedBuilder < ' fmt , ' ast , ' buf > {
74
+ JoinCommaSeparatedBuilder :: new ( self , sequence_end)
66
75
}
67
76
}
68
77
@@ -194,18 +203,20 @@ pub(crate) struct JoinCommaSeparatedBuilder<'fmt, 'ast, 'buf> {
194
203
result : FormatResult < ( ) > ,
195
204
fmt : & ' fmt mut PyFormatter < ' ast , ' buf > ,
196
205
end_of_last_entry : Option < TextSize > ,
206
+ sequence_end : TextSize ,
197
207
/// We need to track whether we have more than one entry since a sole entry doesn't get a
198
208
/// magic trailing comma even when expanded
199
209
len : usize ,
200
210
}
201
211
202
212
impl < ' fmt , ' ast , ' buf > JoinCommaSeparatedBuilder < ' fmt , ' ast , ' buf > {
203
- fn new ( f : & ' fmt mut PyFormatter < ' ast , ' buf > ) -> Self {
213
+ fn new ( f : & ' fmt mut PyFormatter < ' ast , ' buf > , sequence_end : TextSize ) -> Self {
204
214
Self {
205
215
fmt : f,
206
216
result : Ok ( ( ) ) ,
207
217
end_of_last_entry : None ,
208
218
len : 0 ,
219
+ sequence_end,
209
220
}
210
221
}
211
222
@@ -236,7 +247,7 @@ impl<'fmt, 'ast, 'buf> JoinCommaSeparatedBuilder<'fmt, 'ast, 'buf> {
236
247
where
237
248
T : Ranged ,
238
249
F : Format < PyFormatContext < ' ast > > ,
239
- I : Iterator < Item = ( T , F ) > ,
250
+ I : IntoIterator < Item = ( T , F ) > ,
240
251
{
241
252
for ( node, content) in entries {
242
253
self . entry ( & node, & content) ;
@@ -248,7 +259,7 @@ impl<'fmt, 'ast, 'buf> JoinCommaSeparatedBuilder<'fmt, 'ast, 'buf> {
248
259
pub ( crate ) fn nodes < ' a , T , I > ( & mut self , entries : I ) -> & mut Self
249
260
where
250
261
T : Ranged + AsFormat < PyFormatContext < ' ast > > + ' a ,
251
- I : Iterator < Item = & ' a T > ,
262
+ I : IntoIterator < Item = & ' a T > ,
252
263
{
253
264
for node in entries {
254
265
self . entry ( node, & node. format ( ) ) ;
@@ -260,14 +271,26 @@ impl<'fmt, 'ast, 'buf> JoinCommaSeparatedBuilder<'fmt, 'ast, 'buf> {
260
271
pub ( crate ) fn finish ( & mut self ) -> FormatResult < ( ) > {
261
272
self . result . and_then ( |_| {
262
273
if let Some ( last_end) = self . end_of_last_entry . take ( ) {
263
- let magic_trailing_comma = self . fmt . options ( ) . magic_trailing_comma ( ) . is_respect ( )
264
- && matches ! (
265
- first_non_trivia_token( last_end, self . fmt. context( ) . source( ) ) ,
266
- Some ( Token {
267
- kind: TokenKind :: Comma ,
268
- ..
269
- } )
270
- ) ;
274
+ let magic_trailing_comma = match self . fmt . options ( ) . magic_trailing_comma ( ) {
275
+ MagicTrailingComma :: Respect => {
276
+ let first_token = SimpleTokenizer :: new (
277
+ self . fmt . context ( ) . source ( ) ,
278
+ TextRange :: new ( last_end, self . sequence_end ) ,
279
+ )
280
+ . skip_trivia ( )
281
+ // Skip over any closing parentheses belonging to the expression
282
+ . find ( |token| token. kind ( ) != TokenKind :: RParen ) ;
283
+
284
+ matches ! (
285
+ first_token,
286
+ Some ( Token {
287
+ kind: TokenKind :: Comma ,
288
+ ..
289
+ } )
290
+ )
291
+ }
292
+ MagicTrailingComma :: Ignore => false ,
293
+ } ;
271
294
272
295
// If there is a single entry, only keep the magic trailing comma, don't add it if
273
296
// it wasn't there. If there is more than one entry, always add it.
@@ -287,13 +310,15 @@ impl<'fmt, 'ast, 'buf> JoinCommaSeparatedBuilder<'fmt, 'ast, 'buf> {
287
310
288
311
#[ cfg( test) ]
289
312
mod tests {
313
+ use rustpython_parser:: ast:: ModModule ;
314
+ use rustpython_parser:: Parse ;
315
+
316
+ use ruff_formatter:: format;
317
+
290
318
use crate :: comments:: Comments ;
291
319
use crate :: context:: { NodeLevel , PyFormatContext } ;
292
320
use crate :: prelude:: * ;
293
321
use crate :: PyFormatOptions ;
294
- use ruff_formatter:: format;
295
- use rustpython_parser:: ast:: ModModule ;
296
- use rustpython_parser:: Parse ;
297
322
298
323
fn format_ranged ( level : NodeLevel ) -> String {
299
324
let source = r#"
0 commit comments