 import attr
 import pandas as pd
 
+from sagemaker import Session
 from sagemaker.feature_store.feature_group import FeatureGroup
 
 
@@ -33,6 +34,7 @@ class DatasetBuilder:
     an output path and a KMS key ID.
 
     Attributes:
+        _sagemaker_session (Session): Session instance to perform boto calls.
         _base (Union[FeatureGroup, DataFrame]): A base which can be either a FeatureGroup or a
             pandas.DataFrame and will be used to merge other FeatureGroups and generate a Dataset.
         _output_path (str): An S3 URI which stores the output .csv file.
@@ -59,6 +61,7 @@ class DatasetBuilder:
             dataset will be before it.
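+
+    Example (a hedged usage sketch, not part of this diff; ``sagemaker_session``
+    and ``my_feature_group`` are placeholders assumed to already exist, and the
+    keyword names follow the attrs-generated constructor):
+
+        builder = DatasetBuilder(
+            sagemaker_session=sagemaker_session,
+            base=my_feature_group,
+            output_path="s3://my-bucket/dataset-output",
+        )
+        s3_csv_path, query_string = builder.to_csv()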
6062 """
6163
64+ _sagemaker_session : Session = attr .ib ()
6265 _base : Union [FeatureGroup , pd .DataFrame ] = attr .ib ()
6366 _output_path : str = attr .ib ()
6467 _record_identifier_feature_name : str = attr .ib (default = None )
@@ -155,3 +158,104 @@ def with_event_time_range(
         self._event_time_starting_timestamp = starting_timestamp
         self._event_time_ending_timestamp = ending_timestamp
         return self
+
+    def to_csv(self):
+        """Get query string and result in .csv format.
+
+        Returns:
+            The S3 path of the .csv file.
+            The query string executed.
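+
+        Example (illustrative; assumes ``builder`` is a DatasetBuilder whose
+        base FeatureGroup has an offline store with a metastore configured):
+
+            s3_csv_path, query_string = builder.to_csv()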
168+ """
169+ if isinstance (self ._base , FeatureGroup ):
170+ # TODO: handle pagination and input feature validation
171+ base_feature_group = self ._base .describe ()
172+ data_catalog_config = base_feature_group .get ("OfflineStoreConfig" , None ).get (
173+ "DataCatalogConfig" , None
174+ )
175+ if not data_catalog_config :
176+ raise RuntimeError ("No metastore is configured with the base FeatureGroup." )
177+ disable_glue = base_feature_group .get ("DisableGlueTableCreation" , False )
178+ self ._record_identifier_feature_name = base_feature_group .get (
179+ "RecordIdentifierFeatureName" , None
180+ )
181+ self ._event_time_identifier_feature_name = base_feature_group .get (
182+ "EventTimeFeatureName" , None
183+ )
184+ base_features = [
185+ feature .get ("FeatureName" , None )
186+ for feature in base_feature_group .get ("FeatureDefinitions" , None )
187+ ]
188+
+            # Query the offline store table through Athena, writing results to
+            # the configured S3 output path.
+            query = self._sagemaker_session.start_query_execution(
+                catalog=data_catalog_config.get("Catalog", None)
+                if disable_glue
+                else "AwsDataCatalog",
+                database=data_catalog_config.get("Database", None),
+                query_string=self._construct_query_string(
+                    data_catalog_config.get("TableName", None),
+                    data_catalog_config.get("Database", None),
+                    base_features,
+                ),
+                output_location=self._output_path,
+                kms_key=self._kms_key_id,
+            )
+            query_id = query.get("QueryExecutionId", None)
+            self._sagemaker_session.wait_for_athena_query(
+                query_execution_id=query_id,
+            )
+            # Re-fetch the execution so both the final state and the result
+            # location come from the same response.
+            query_result = self._sagemaker_session.get_query_execution(
+                query_execution_id=query_id,
+            )
+            query_state = (
+                query_result.get("QueryExecution", {}).get("Status", {}).get("State", None)
+            )
+            if query_state != "SUCCEEDED":
+                raise RuntimeError(f"Failed to execute query {query_id}.")
+
+            return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get(
+                "OutputLocation", None
+            ), query_result.get("QueryExecution", {}).get("Query", None)
+        raise ValueError("Base must be either a FeatureGroup or a DataFrame.")
+
+    def _construct_query_string(
+        self, base_table_name: str, database: str, base_features: list
+    ) -> str:
+        """Internal method for constructing the SQL query string from the given parameters.
+
+        Args:
+            base_table_name (str): The Athena table name of the base FeatureGroup or
+                pandas.DataFrame.
+            database (str): The Athena database of the base table.
+            base_features (list): The list of features of the base table.
+        Returns:
+            The query string.
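+
+        Example (illustrative, hypothetical names): with included features
+        ["id", "value"], event time feature "event_time", base table
+        "my-feature-group-123" in database "sagemaker_featurestore", and both
+        duplicated and deleted records excluded, the constructed query is
+        (wrapped here for width):
+
+            SELECT base."id", base."value"
+            FROM (
+            SELECT *, row_number() OVER (
+            PARTITION BY dedup_base."id", dedup_base."value"
+            ORDER BY dedup_base."event_time" DESC,
+            dedup_base."api_invocation_time" DESC, dedup_base."write_time" DESC
+            ) AS row_base
+            FROM "sagemaker_featurestore"."my-feature-group-123" dedup_base
+            ) AS base
+            WHERE row_base = 1
+            AND NOT is_deleted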
233+ """
234+ included_features = ", " .join (
235+ [
236+ f'base."{ include_feature_name } "'
237+ for include_feature_name in self ._included_feature_names
238+ ]
239+ )
240+ query_string = f"SELECT { included_features } \n "
241+ if self ._include_duplicated_records :
242+ query_string += f'FROM "{ database } "."{ base_table_name } " base\n '
243+ if not self ._include_deleted_records :
244+ query_string += "WHERE NOT is_deleted\n "
245+ else :
246+ base_features .remove (self ._event_time_identifier_feature_name )
247+ dedup_features = ", " .join ([f'dedup_base."{ feature } "' for feature in base_features ])
248+ query_string += (
249+ "FROM (\n "
250+ + "SELECT *, row_number() OVER (\n "
251+ + f"PARTITION BY { dedup_features } \n "
252+ + f'ORDER BY dedup_base."{ self ._event_time_identifier_feature_name } " '
253+ + 'DESC, dedup_base."api_invocation_time" DESC, dedup_base."write_time" DESC\n '
254+ + ") AS row_base\n "
255+ + f'FROM "{ database } "."{ base_table_name } " dedup_base\n '
256+ + ") AS base\n "
257+ + "WHERE row_base = 1\n "
258+ )
259+ if not self ._include_deleted_records :
260+ query_string += "AND NOT is_deleted\n "
261+ return query_string