@@ -122,8 +122,8 @@ def format_label(self, label, name=None):
122
122
"BYTES" : types .BINARY ,
123
123
"TIME" : types .TIME ,
124
124
"RECORD" : types .JSON ,
125
- "NUMERIC" : types .DECIMAL ,
126
- "BIGNUMERIC" : types .DECIMAL ,
125
+ "NUMERIC" : types .Numeric ,
126
+ "BIGNUMERIC" : types .Numeric ,
127
127
}
128
128
129
129
STRING = _type_map ["STRING" ]
@@ -158,23 +158,33 @@ def get_insert_default(self, column): # pragma: NO COVER
158
158
elif isinstance (column .type , String ):
159
159
return str (uuid .uuid4 ())
160
160
161
- __remove_type_from_empty_in = _helpers .substitute_re_method (
162
- r" IN UNNEST\(\[ ("
163
- r"(?:NULL|\(NULL(?:, NULL)+\))\)"
164
- r" (?:AND|OR) \(1 !?= 1"
165
- r")"
166
- r"(?:[:][A-Z0-9]+)?"
167
- r" \]\)" ,
168
- re .IGNORECASE ,
169
- r" IN(\1)" ,
161
+ __remove_type_from_empty_in = _helpers .substitute_string_re_method (
162
+ r"""
163
+ \sIN\sUNNEST\(\[\s # ' IN UNNEST([ '
164
+ (
165
+ (?:NULL|\(NULL(?:,\sNULL)+\))\) # '(NULL)' or '((NULL, NULL, ...))'
166
+ \s(?:AND|OR)\s\(1\s!?=\s1 # ' and 1 != 1' or ' or 1 = 1'
167
+ )
168
+ (?:[:][A-Z0-9]+)? # Maybe ':TYPE' (e.g. ':INT64')
169
+ \s\]\) # Close: ' ])'
170
+ """ ,
171
+ flags = re .IGNORECASE | re .VERBOSE ,
172
+ repl = r" IN(\1)" ,
170
173
)
171
174
172
175
@_helpers .substitute_re_method (
173
- r" IN UNNEST\(\[ "
174
- r"(%\([^)]+_\d+\)s(?:, %\([^)]+_\d+\)s)*)?" # Placeholders. See below.
175
- r":([A-Z0-9]+)" # Type
176
- r" \]\)" ,
177
- re .IGNORECASE ,
176
+ r"""
177
+ \sIN\sUNNEST\(\[\s # ' IN UNNEST([ '
178
+ ( # Placeholders. See below.
179
+ %\([^)]+_\d+\)s # Placeholder '%(foo_1)s'
180
+ (?:,\s # 0 or more placeholders
181
+ %\([^)]+_\d+\)s
182
+ )*
183
+ )?
184
+ :([A-Z0-9]+) # Type ':TYPE' (e.g. ':INT64')
185
+ \s\]\) # Close: ' ])'
186
+ """ ,
187
+ flags = re .IGNORECASE | re .VERBOSE ,
178
188
)
179
189
def __distribute_types_to_expanded_placeholders (self , m ):
180
190
# If we have an in parameter, it sometimes gets expaned to 0 or more
@@ -282,10 +292,20 @@ def group_by_clause(self, select, **kw):
282
292
"EXPANDING" if __sqlalchemy_version_info < (1 , 4 ) else "POSTCOMPILE"
283
293
)
284
294
285
- __in_expanding_bind = _helpers .substitute_re_method (
286
- fr" IN \((\[" fr"{ __expandng_text } " fr"_[^\]]+\](:[A-Z0-9]+)?)\)$" ,
287
- re .IGNORECASE ,
288
- r" IN UNNEST([ \1 ])" ,
295
+ __in_expanding_bind = _helpers .substitute_string_re_method (
296
+ fr"""
297
+ \sIN\s\( # ' IN ('
298
+ (
299
+ \[ # Expanding placeholder
300
+ { __expandng_text } # e.g. [EXPANDING_foo_1]
301
+ _[^\]]+ #
302
+ \]
303
+ (:[A-Z0-9]+)? # type marker (e.g. ':INT64'
304
+ )
305
+ \)$ # close w ending )
306
+ """ ,
307
+ flags = re .IGNORECASE | re .VERBOSE ,
308
+ repl = r" IN UNNEST([ \1 ])" ,
289
309
)
290
310
291
311
def visit_in_op_binary (self , binary , operator_ , ** kw ):
@@ -360,6 +380,18 @@ def visit_notendswith_op_binary(self, binary, operator, **kw):
360
380
361
381
__expanded_param = re .compile (fr"\(\[" fr"{ __expandng_text } " fr"_[^\]]+\]\)$" ).match
362
382
383
+ __remove_type_parameter = _helpers .substitute_string_re_method (
384
+ r"""
385
+ (STRING|BYTES|NUMERIC|BIGNUMERIC) # Base type
386
+ \( # Dimensions e.g. '(42)', '(4, 2)':
387
+ \s*\d+\s* # First dimension
388
+ (?:,\s*\d+\s*)* # Remaining dimensions
389
+ \)
390
+ """ ,
391
+ repl = r"\1" ,
392
+ flags = re .VERBOSE | re .IGNORECASE ,
393
+ )
394
+
363
395
def visit_bindparam (
364
396
self ,
365
397
bindparam ,
@@ -397,6 +429,7 @@ def visit_bindparam(
397
429
if bq_type [- 1 ] == ">" and bq_type .startswith ("ARRAY<" ):
398
430
# Values get arrayified at a lower level.
399
431
bq_type = bq_type [6 :- 1 ]
432
+ bq_type = self .__remove_type_parameter (bq_type )
400
433
401
434
assert_ (param != "%s" , f"Unexpected param: { param } " )
402
435
@@ -429,6 +462,10 @@ def visit_FLOAT(self, type_, **kw):
429
462
visit_REAL = visit_FLOAT
430
463
431
464
def visit_STRING (self , type_ , ** kw ):
465
+ if (type_ .length is not None ) and isinstance (
466
+ kw .get ("type_expression" ), Column
467
+ ): # column def
468
+ return f"STRING({ type_ .length } )"
432
469
return "STRING"
433
470
434
471
visit_CHAR = visit_NCHAR = visit_STRING
@@ -438,17 +475,29 @@ def visit_ARRAY(self, type_, **kw):
438
475
return "ARRAY<{}>" .format (self .process (type_ .item_type , ** kw ))
439
476
440
477
def visit_BINARY (self , type_ , ** kw ):
478
+ if type_ .length is not None :
479
+ return f"BYTES({ type_ .length } )"
441
480
return "BYTES"
442
481
443
482
visit_VARBINARY = visit_BINARY
444
483
445
484
def visit_NUMERIC (self , type_ , ** kw ):
446
- if (type_ .precision is not None and type_ .precision > 38 ) or (
447
- type_ .scale is not None and type_ .scale > 9
448
- ):
449
- return "BIGNUMERIC"
485
+ if (type_ .precision is not None ) and isinstance (
486
+ kw .get ("type_expression" ), Column
487
+ ): # column def
488
+ if type_ .scale is not None :
489
+ suffix = f"({ type_ .precision } , { type_ .scale } )"
490
+ else :
491
+ suffix = f"({ type_ .precision } )"
450
492
else :
451
- return "NUMERIC"
493
+ suffix = ""
494
+
495
+ return (
496
+ "BIGNUMERIC"
497
+ if (type_ .precision is not None and type_ .precision > 38 )
498
+ or (type_ .scale is not None and type_ .scale > 9 )
499
+ else "NUMERIC"
500
+ ) + suffix
452
501
453
502
visit_DECIMAL = visit_NUMERIC
454
503
@@ -800,18 +849,16 @@ def _get_columns_helper(self, columns, cur_columns):
800
849
"""
801
850
results = []
802
851
for col in columns :
803
- results += [
804
- SchemaField (
805
- name = "." .join (col .name for col in cur_columns + [col ]),
806
- field_type = col .field_type ,
807
- mode = col .mode ,
808
- description = col .description ,
809
- fields = col .fields ,
810
- )
811
- ]
852
+ results += [col ]
812
853
if col .field_type == "RECORD" :
813
854
cur_columns .append (col )
814
- results += self ._get_columns_helper (col .fields , cur_columns )
855
+ fields = [
856
+ SchemaField .from_api_repr (
857
+ dict (f .to_api_repr (), name = f"{ col .name } .{ f .name } " )
858
+ )
859
+ for f in col .fields
860
+ ]
861
+ results += self ._get_columns_helper (fields , cur_columns )
815
862
cur_columns .pop ()
816
863
return results
817
864
@@ -829,13 +876,21 @@ def get_columns(self, connection, table_name, schema=None, **kw):
829
876
)
830
877
coltype = types .NullType
831
878
879
+ if col .field_type .endswith ("NUMERIC" ):
880
+ coltype = coltype (precision = col .precision , scale = col .scale )
881
+ elif col .field_type == "STRING" or col .field_type == "BYTES" :
882
+ coltype = coltype (col .max_length )
883
+
832
884
result .append (
833
885
{
834
886
"name" : col .name ,
835
887
"type" : types .ARRAY (coltype ) if col .mode == "REPEATED" else coltype ,
836
888
"nullable" : col .mode == "NULLABLE" or col .mode == "REPEATED" ,
837
889
"comment" : col .description ,
838
890
"default" : None ,
891
+ "precision" : col .precision ,
892
+ "scale" : col .scale ,
893
+ "max_length" : col .max_length ,
839
894
}
840
895
)
841
896
0 commit comments