@@ -410,6 +410,31 @@ impl PgReplicationClient {
410410 & self ,
411411 publication_name : & str ,
412412 ) -> EtlResult < Vec < TableId > > {
413+ // Prefer pg_publication_rel (explicit tables in the publication, including partition roots)
414+ let rel_query = format ! (
415+ r#"select r.prrelid as oid
416+ from pg_publication_rel r
417+ join pg_publication p on p.oid = r.prpubid
418+ where p.pubname = {}"# ,
419+ quote_literal( publication_name)
420+ ) ;
421+
422+ let mut table_ids = vec ! [ ] ;
423+ let mut has_rows = false ;
424+ for msg in self . client . simple_query ( & rel_query) . await ? {
425+ if let SimpleQueryMessage :: Row ( row) = msg {
426+ has_rows = true ;
427+ let table_id =
428+ Self :: get_row_value :: < TableId > ( & row, "oid" , "pg_publication_rel" ) . await ?;
429+ table_ids. push ( table_id) ;
430+ }
431+ }
432+
433+ if has_rows {
434+ return Ok ( table_ids) ;
435+ }
436+
437+ // Fallback to pg_publication_tables (expanded view), used for publications like FOR ALL TABLES
413438 let publication_query = format ! (
414439 "select c.oid from pg_publication_tables pt
415440 join pg_class c on c.relname = pt.tablename
@@ -418,10 +443,8 @@ impl PgReplicationClient {
418443 quote_literal( publication_name)
419444 ) ;
420445
421- let mut table_ids = vec ! [ ] ;
422446 for msg in self . client . simple_query ( & publication_query) . await ? {
423447 if let SimpleQueryMessage :: Row ( row) = msg {
424- // For the sake of simplicity, we refer to the table oid as table id.
425448 let table_id = Self :: get_row_value :: < TableId > ( & row, "oid" , "pg_class" ) . await ?;
426449 table_ids. push ( table_id) ;
427450 }
@@ -721,44 +744,19 @@ impl PgReplicationClient {
721744 join direct_parent dp on con.conrelid = dp.parent_oid
722745 where con.contype = 'p'
723746 group by con.conname
724- ),
725- -- Check if current table has a unique index on the parent PK columns
726- partition_has_pk_index as (
727- select case
728- when exists (select 1 from direct_parent)
729- and exists (select 1 from parent_pk_cols)
730- and exists (
731- -- Check if there's a unique, valid index on the parent PK columns
732- select 1
733- from pg_index ix
734- cross join parent_pk_cols pk
735- where ix.indrelid = {table_id}::oid
736- and ix.indisunique = true
737- and ix.indisvalid = true
738- and array(
739- select a.attname
740- from unnest(ix.indkey) with ordinality k(attnum, ord)
741- join pg_attribute a on a.attrelid = ix.indrelid and a.attnum = k.attnum
742- where ord <= ix.indnkeyatts -- exclude INCLUDE columns
743- order by ord
744- ) = pk.pk_column_names
745- ) then true
746- else false
747- end as has_inherited_pk
748747 )
749748 SELECT a.attname,
750749 a.atttypid,
751750 a.atttypmod,
752751 a.attnotnull,
753752 case
754- -- First check for direct primary key
753+ -- Direct primary key on this relation
755754 when coalesce(i.indisprimary, false) = true then true
756- -- Then check for inherited primary key from partitioned table parent
757- when (select has_inherited_pk from partition_has_pk_index) = true
758- and exists (
759- select 1 from parent_pk_cols pk
760- where a.attname = any(pk.pk_column_names)
761- ) then true
755+ -- Inherit primary key from parent partitioned table if column name matches
756+ when exists (
757+ select 1 from parent_pk_cols pk
758+ where a.attname = any(pk.pk_column_names)
759+ ) then true
762760 else false
763761 end as primary
764762 from pg_attribute a
@@ -816,17 +814,34 @@ impl PgReplicationClient {
816814 . collect :: < Vec < _ > > ( )
817815 . join ( ", " ) ;
818816
819- let table_name = self . get_table_name ( table_id) . await ?;
817+ let copy_query = if self . is_partitioned_table ( table_id) . await ?
818+ && let leaf_partitions = self . get_leaf_partition_ids ( table_id) . await ?
819+ && !leaf_partitions. is_empty ( )
820+ {
821+ let mut selects = Vec :: with_capacity ( leaf_partitions. len ( ) ) ;
822+ for child_id in leaf_partitions {
823+ let child_name = self . get_table_name ( child_id) . await ?;
824+ let select = format ! (
825+ "select {} from {}" ,
826+ column_list,
827+ child_name. as_quoted_identifier( )
828+ ) ;
829+ selects. push ( select) ;
830+ }
820831
821- // TODO: allow passing in format binary or text
822- let copy_query = format ! (
823- r#"copy {} ({}) to stdout with (format text);"# ,
824- table_name. as_quoted_identifier( ) ,
825- column_list
826- ) ;
832+ let union_query = selects. join ( " union all " ) ;
833+ format ! ( r#"copy ({}) to stdout with (format text);"# , union_query)
834+ } else {
835+ let table_name = self . get_table_name ( table_id) . await ?;
836+ format ! (
837+ r#"copy {} ({}) to stdout with (format text);"# ,
838+ table_name. as_quoted_identifier( ) ,
839+ column_list
840+ )
841+ } ;
827842
843+ // TODO: allow passing in format binary or text
828844 let stream = self . client . copy_out_simple ( & copy_query) . await ?;
829-
830845 Ok ( stream)
831846 }
832847
@@ -861,4 +876,57 @@ impl PgReplicationClient {
861876 )
862877 } )
863878 }
879+
880+ /// Returns true if the given table id refers to a partitioned table (relkind = 'p').
881+ async fn is_partitioned_table ( & self , table_id : TableId ) -> EtlResult < bool > {
882+ let query = format ! (
883+ "select c.relkind from pg_class c where c.oid = {}" ,
884+ table_id
885+ ) ;
886+
887+ for msg in self . client . simple_query ( & query) . await ? {
888+ if let SimpleQueryMessage :: Row ( row) = msg {
889+ let relkind = Self :: get_row_value :: < String > ( & row, "relkind" , "pg_class" ) . await ?;
890+ return Ok ( relkind == "p" ) ;
891+ }
892+ }
893+
894+ bail ! (
895+ ErrorKind :: SourceSchemaError ,
896+ "Table not found" ,
897+ format!( "Table not found in database (table id: {})" , table_id)
898+ ) ;
899+ }
900+
901+ /// Returns all leaf partition OIDs for a partitioned table.
902+ async fn get_leaf_partition_ids ( & self , parent_id : TableId ) -> EtlResult < Vec < TableId > > {
903+ let query = format ! (
904+ r#"
905+ with recursive parts(relid) as (
906+ select i.inhrelid
907+ from pg_inherits i
908+ where i.inhparent = {parent}
909+ union all
910+ select i.inhrelid
911+ from pg_inherits i
912+ join parts p on p.relid = i.inhparent
913+ )
914+ select p.relid as oid
915+ from parts p
916+ left join pg_inherits i on i.inhparent = p.relid
917+ where i.inhrelid is null
918+ "# ,
919+ parent = parent_id
920+ ) ;
921+
922+ let mut ids = Vec :: new ( ) ;
923+ for msg in self . client . simple_query ( & query) . await ? {
924+ if let SimpleQueryMessage :: Row ( row) = msg {
925+ let id = Self :: get_row_value :: < TableId > ( & row, "oid" , "pg_inherits" ) . await ?;
926+ ids. push ( id) ;
927+ }
928+ }
929+
930+ Ok ( ids)
931+ }
864932}
0 commit comments