import contextlib
import uuid
from datetime import date, datetime, timedelta
from typing import (
Callable,
ContextManager,
Dict,
Iterator,
List,
Optional,
Tuple,
Union,
)
import numpy as np
import pandas as pd
import pyarrow
import pyarrow.parquet
from pydantic import StrictStr
from pydantic.typing import Literal
from tenacity import Retrying, retry_if_exception_type, stop_after_delay, wait_fixed
from feast import flags_helper
from feast.data_source import DataSource
from feast.errors import (
BigQueryJobCancelled,
BigQueryJobStillRunning,
FeastProviderLoginError,
InvalidEntityType,
)
from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView
from feast.infra.offline_stores import offline_utils
from feast.infra.offline_stores.offline_store import (
OfflineStore,
RetrievalJob,
RetrievalMetadata,
)
from feast.on_demand_feature_view import OnDemandFeatureView
from feast.registry import Registry
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from ...saved_dataset import SavedDatasetStorage
from ...usage import log_exceptions_and_usage
from .bigquery_source import BigQuerySource, SavedDatasetBigQueryStorage
try:
from google.api_core.exceptions import NotFound
from google.auth.exceptions import DefaultCredentialsError
from google.cloud import bigquery
from google.cloud.bigquery import Client, Table
except ImportError as e:
from feast.errors import FeastExtrasDependencyImportError
raise FeastExtrasDependencyImportError("gcp", str(e))
[docs]class BigQueryOfflineStoreConfig(FeastConfigBaseModel):
"""Offline store config for GCP BigQuery"""
type: Literal["bigquery"] = "bigquery"
""" Offline store type selector"""
dataset: StrictStr = "feast"
""" (optional) BigQuery Dataset name for temporary tables """
project_id: Optional[StrictStr] = None
""" (optional) GCP project name used for the BigQuery offline store """
location: Optional[StrictStr] = None
""" (optional) GCP location name used for the BigQuery offline store.
Examples of location names include ``US``, ``EU``, ``us-central1``, ``us-west4``.
If a location is not specified, the location defaults to the ``US`` multi-regional location.
For more information on BigQuery data locations see: https://cloud.google.com/bigquery/docs/locations
"""
[docs]class BigQueryOfflineStore(OfflineStore):
[docs] @staticmethod
@log_exceptions_and_usage(offline_store="bigquery")
def pull_latest_from_table_or_query(
config: RepoConfig,
data_source: DataSource,
join_key_columns: List[str],
feature_name_columns: List[str],
event_timestamp_column: str,
created_timestamp_column: Optional[str],
start_date: datetime,
end_date: datetime,
) -> RetrievalJob:
assert isinstance(data_source, BigQuerySource)
from_expression = data_source.get_table_query_string()
partition_by_join_key_string = ", ".join(join_key_columns)
if partition_by_join_key_string != "":
partition_by_join_key_string = (
"PARTITION BY " + partition_by_join_key_string
)
timestamps = [event_timestamp_column]
if created_timestamp_column:
timestamps.append(created_timestamp_column)
timestamp_desc_string = " DESC, ".join(timestamps) + " DESC"
field_string = ", ".join(join_key_columns + feature_name_columns + timestamps)
client = _get_bigquery_client(
project=config.offline_store.project_id,
location=config.offline_store.location,
)
query = f"""
SELECT
{field_string}
{f", {repr(DUMMY_ENTITY_VAL)} AS {DUMMY_ENTITY_ID}" if not join_key_columns else ""}
FROM (
SELECT {field_string},
ROW_NUMBER() OVER({partition_by_join_key_string} ORDER BY {timestamp_desc_string}) AS _feast_row
FROM {from_expression}
WHERE {event_timestamp_column} BETWEEN TIMESTAMP('{start_date}') AND TIMESTAMP('{end_date}')
)
WHERE _feast_row = 1
"""
# When materializing a single feature view, we don't need full feature names. On demand transforms aren't materialized
return BigQueryRetrievalJob(
query=query, client=client, config=config, full_feature_names=False,
)
[docs] @staticmethod
@log_exceptions_and_usage(offline_store="bigquery")
def pull_all_from_table_or_query(
config: RepoConfig,
data_source: DataSource,
join_key_columns: List[str],
feature_name_columns: List[str],
event_timestamp_column: str,
start_date: datetime,
end_date: datetime,
) -> RetrievalJob:
assert isinstance(data_source, BigQuerySource)
from_expression = data_source.get_table_query_string()
client = _get_bigquery_client(
project=config.offline_store.project_id,
location=config.offline_store.location,
)
field_string = ", ".join(
join_key_columns + feature_name_columns + [event_timestamp_column]
)
query = f"""
SELECT {field_string}
FROM {from_expression}
WHERE {event_timestamp_column} BETWEEN TIMESTAMP('{start_date}') AND TIMESTAMP('{end_date}')
"""
return BigQueryRetrievalJob(
query=query, client=client, config=config, full_feature_names=False,
)
[docs] @staticmethod
@log_exceptions_and_usage(offline_store="bigquery")
def get_historical_features(
config: RepoConfig,
feature_views: List[FeatureView],
feature_refs: List[str],
entity_df: Union[pd.DataFrame, str],
registry: Registry,
project: str,
full_feature_names: bool = False,
) -> RetrievalJob:
# TODO: Add entity_df validation in order to fail before interacting with BigQuery
assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)
client = _get_bigquery_client(
project=config.offline_store.project_id,
location=config.offline_store.location,
)
assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)
table_reference = _get_table_reference_for_new_entity(
client,
client.project,
config.offline_store.dataset,
config.offline_store.location,
)
entity_schema = _get_entity_schema(client=client, entity_df=entity_df,)
entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
entity_schema
)
entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(
entity_df, entity_df_event_timestamp_col, client,
)
@contextlib.contextmanager
def query_generator() -> Iterator[str]:
_upload_entity_df(
client=client, table_name=table_reference, entity_df=entity_df,
)
expected_join_keys = offline_utils.get_expected_join_keys(
project, feature_views, registry
)
offline_utils.assert_expected_columns_in_entity_df(
entity_schema, expected_join_keys, entity_df_event_timestamp_col
)
# Build a query context containing all information required to template the BigQuery SQL query
query_context = offline_utils.get_feature_view_query_context(
feature_refs,
feature_views,
registry,
project,
entity_df_event_timestamp_range,
)
# Generate the BigQuery SQL query from the query context
query = offline_utils.build_point_in_time_query(
query_context,
left_table_query_string=table_reference,
entity_df_event_timestamp_col=entity_df_event_timestamp_col,
entity_df_columns=entity_schema.keys(),
query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
full_feature_names=full_feature_names,
)
try:
yield query
finally:
# Asynchronously clean up the uploaded Bigquery table, which will expire
# if cleanup fails
client.delete_table(table=table_reference, not_found_ok=True)
return BigQueryRetrievalJob(
query=query_generator,
client=client,
config=config,
full_feature_names=full_feature_names,
on_demand_feature_views=OnDemandFeatureView.get_requested_odfvs(
feature_refs, project, registry
),
metadata=RetrievalMetadata(
features=feature_refs,
keys=list(entity_schema.keys() - {entity_df_event_timestamp_col}),
min_event_timestamp=entity_df_event_timestamp_range[0],
max_event_timestamp=entity_df_event_timestamp_range[1],
),
)
[docs]class BigQueryRetrievalJob(RetrievalJob):
def __init__(
self,
query: Union[str, Callable[[], ContextManager[str]]],
client: bigquery.Client,
config: RepoConfig,
full_feature_names: bool,
on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None,
metadata: Optional[RetrievalMetadata] = None,
):
if not isinstance(query, str):
self._query_generator = query
else:
@contextlib.contextmanager
def query_generator() -> Iterator[str]:
assert isinstance(query, str)
yield query
self._query_generator = query_generator
self.client = client
self.config = config
self._full_feature_names = full_feature_names
self._on_demand_feature_views = (
on_demand_feature_views if on_demand_feature_views else []
)
self._metadata = metadata
@property
def full_feature_names(self) -> bool:
return self._full_feature_names
@property
def on_demand_feature_views(self) -> Optional[List[OnDemandFeatureView]]:
return self._on_demand_feature_views
def _to_df_internal(self) -> pd.DataFrame:
with self._query_generator() as query:
df = self._execute_query(query).to_dataframe(create_bqstorage_client=True)
return df
[docs] def to_sql(self) -> str:
"""
Returns the SQL query that will be executed in BigQuery to build the historical feature table.
"""
with self._query_generator() as query:
return query
[docs] def to_bigquery(
self,
job_config: bigquery.QueryJobConfig = None,
timeout: int = 1800,
retry_cadence: int = 10,
) -> Optional[str]:
"""
Triggers the execution of a historical feature retrieval query and exports the results to a BigQuery table.
Runs for a maximum amount of time specified by the timeout parameter (defaulting to 30 minutes).
Args:
job_config: An optional bigquery.QueryJobConfig to specify options like destination table, dry run, etc.
timeout: An optional number of seconds for setting the time limit of the QueryJob.
retry_cadence: An optional number of seconds for setting how long the job should checked for completion.
Returns:
Returns the destination table name or returns None if job_config.dry_run is True.
"""
if not job_config:
today = date.today().strftime("%Y%m%d")
rand_id = str(uuid.uuid4())[:7]
path = f"{self.client.project}.{self.config.offline_store.dataset}.historical_{today}_{rand_id}"
job_config = bigquery.QueryJobConfig(destination=path)
if not job_config.dry_run and self.on_demand_feature_views:
job = self.client.load_table_from_dataframe(
self.to_df(), job_config.destination
)
job.result()
print(f"Done writing to '{job_config.destination}'.")
return str(job_config.destination)
with self._query_generator() as query:
self._execute_query(query, job_config, timeout)
print(f"Done writing to '{job_config.destination}'.")
return str(job_config.destination)
def _to_arrow_internal(self) -> pyarrow.Table:
with self._query_generator() as query:
q = self._execute_query(query=query)
assert q
return q.to_arrow()
@log_exceptions_and_usage
def _execute_query(
self, query, job_config=None, timeout: int = 1800
) -> Optional[bigquery.job.query.QueryJob]:
bq_job = self.client.query(query, job_config=job_config)
if job_config and job_config.dry_run:
print(
"This query will process {} bytes.".format(bq_job.total_bytes_processed)
)
return None
block_until_done(client=self.client, bq_job=bq_job, timeout=timeout)
return bq_job
[docs] def persist(self, storage: SavedDatasetStorage):
assert isinstance(storage, SavedDatasetBigQueryStorage)
self.to_bigquery(
bigquery.QueryJobConfig(destination=storage.bigquery_options.table)
)
@property
def metadata(self) -> Optional[RetrievalMetadata]:
return self._metadata
[docs]def block_until_done(
client: Client,
bq_job: Union[bigquery.job.query.QueryJob, bigquery.job.load.LoadJob],
timeout: int = 1800,
retry_cadence: float = 1,
):
"""
Waits for bq_job to finish running, up to a maximum amount of time specified by the timeout parameter (defaulting to 30 minutes).
Args:
client: A bigquery.client.Client to monitor the bq_job.
bq_job: The bigquery.job.QueryJob that blocks until done runnning.
timeout: An optional number of seconds for setting the time limit of the job.
retry_cadence: An optional number of seconds for setting how long the job should checked for completion.
Raises:
BigQueryJobStillRunning exception if the function has blocked longer than 30 minutes.
BigQueryJobCancelled exception to signify when that the job has been cancelled (i.e. from timeout or KeyboardInterrupt).
"""
# For test environments, retry more aggressively
if flags_helper.is_test():
retry_cadence = 0.1
def _wait_until_done(bq_job):
if client.get_job(bq_job).state in ["PENDING", "RUNNING"]:
raise BigQueryJobStillRunning(job_id=bq_job.job_id)
try:
retryer = Retrying(
wait=wait_fixed(retry_cadence),
stop=stop_after_delay(timeout),
retry=retry_if_exception_type(BigQueryJobStillRunning),
reraise=True,
)
retryer(_wait_until_done, bq_job)
finally:
if client.get_job(bq_job).state in ["PENDING", "RUNNING"]:
client.cancel_job(bq_job)
raise BigQueryJobCancelled(job_id=bq_job.job_id)
if bq_job.exception():
raise bq_job.exception()
def _get_table_reference_for_new_entity(
client: Client,
dataset_project: str,
dataset_name: str,
dataset_location: Optional[str],
) -> str:
"""Gets the table_id for the new entity to be uploaded."""
# First create the BigQuery dataset if it doesn't exist
dataset = bigquery.Dataset(f"{dataset_project}.{dataset_name}")
dataset.location = dataset_location if dataset_location else "US"
try:
client.get_dataset(dataset.reference)
except NotFound:
# Only create the dataset if it does not exist
client.create_dataset(dataset, exists_ok=True)
table_name = offline_utils.get_temp_entity_table_name()
return f"{dataset_project}.{dataset_name}.{table_name}"
def _upload_entity_df(
client: Client, table_name: str, entity_df: Union[pd.DataFrame, str],
) -> Table:
"""Uploads a Pandas entity dataframe into a BigQuery table and returns the resulting table"""
if isinstance(entity_df, str):
job = client.query(f"CREATE TABLE {table_name} AS ({entity_df})")
elif isinstance(entity_df, pd.DataFrame):
# Drop the index so that we don't have unnecessary columns
entity_df.reset_index(drop=True, inplace=True)
job = client.load_table_from_dataframe(entity_df, table_name)
else:
raise InvalidEntityType(type(entity_df))
block_until_done(client, job)
# Ensure that the table expires after some time
table = client.get_table(table=table_name)
table.expires = datetime.utcnow() + timedelta(minutes=30)
client.update_table(table, ["expires"])
return table
def _get_entity_schema(
client: Client, entity_df: Union[pd.DataFrame, str]
) -> Dict[str, np.dtype]:
if isinstance(entity_df, str):
entity_df_sample = (
client.query(f"SELECT * FROM ({entity_df}) LIMIT 1").result().to_dataframe()
)
entity_schema = dict(zip(entity_df_sample.columns, entity_df_sample.dtypes))
elif isinstance(entity_df, pd.DataFrame):
entity_schema = dict(zip(entity_df.columns, entity_df.dtypes))
else:
raise InvalidEntityType(type(entity_df))
return entity_schema
def _get_entity_df_event_timestamp_range(
entity_df: Union[pd.DataFrame, str],
entity_df_event_timestamp_col: str,
client: Client,
) -> Tuple[datetime, datetime]:
if type(entity_df) is str:
job = client.query(
f"SELECT MIN({entity_df_event_timestamp_col}) AS min, MAX({entity_df_event_timestamp_col}) AS max "
f"FROM ({entity_df})"
)
res = next(job.result())
entity_df_event_timestamp_range = (
res.get("min"),
res.get("max"),
)
elif isinstance(entity_df, pd.DataFrame):
entity_df_event_timestamp = entity_df.loc[
:, entity_df_event_timestamp_col
].infer_objects()
if pd.api.types.is_string_dtype(entity_df_event_timestamp):
entity_df_event_timestamp = pd.to_datetime(
entity_df_event_timestamp, utc=True
)
entity_df_event_timestamp_range = (
entity_df_event_timestamp.min().to_pydatetime(),
entity_df_event_timestamp.max().to_pydatetime(),
)
else:
raise InvalidEntityType(type(entity_df))
return entity_df_event_timestamp_range
def _get_bigquery_client(project: Optional[str] = None, location: Optional[str] = None):
try:
client = bigquery.Client(project=project, location=location)
except DefaultCredentialsError as e:
raise FeastProviderLoginError(
str(e)
+ '\nIt may be necessary to run "gcloud auth application-default login" if you would like to use your '
"local Google Cloud account"
)
except EnvironmentError as e:
raise FeastProviderLoginError(
"GCP error: "
+ str(e)
+ "\nIt may be necessary to set a default GCP project by running "
'"gcloud config set project your-project"'
)
return client
# TODO: Optimizations
# * Use GENERATE_UUID() instead of ROW_NUMBER(), or join on entity columns directly
# * Precompute ROW_NUMBER() so that it doesn't have to be recomputed for every query on entity_dataframe
# * Create temporary tables instead of keeping all tables in memory
# Note: Keep this in sync with sdk/python/feast/infra/offline_stores/redshift.py:MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN
MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN = """
/*
Compute a deterministic hash for the `left_table_query_string` that will be used throughout
all the logic as the field to GROUP BY the data
*/
WITH entity_dataframe AS (
SELECT *,
{{entity_df_event_timestamp_col}} AS entity_timestamp
{% for featureview in featureviews %}
{% if featureview.entities %}
,CONCAT(
{% for entity in featureview.entities %}
CAST({{entity}} AS STRING),
{% endfor %}
CAST({{entity_df_event_timestamp_col}} AS STRING)
) AS {{featureview.name}}__entity_row_unique_id
{% else %}
,CAST({{entity_df_event_timestamp_col}} AS STRING) AS {{featureview.name}}__entity_row_unique_id
{% endif %}
{% endfor %}
FROM `{{ left_table_query_string }}`
),
{% for featureview in featureviews %}
{{ featureview.name }}__entity_dataframe AS (
SELECT
{{ featureview.entities | join(', ')}}{% if featureview.entities %},{% else %}{% endif %}
entity_timestamp,
{{featureview.name}}__entity_row_unique_id
FROM entity_dataframe
GROUP BY
{{ featureview.entities | join(', ')}}{% if featureview.entities %},{% else %}{% endif %}
entity_timestamp,
{{featureview.name}}__entity_row_unique_id
),
/*
This query template performs the point-in-time correctness join for a single feature set table
to the provided entity table.
1. We first join the current feature_view to the entity dataframe that has been passed.
This JOIN has the following logic:
- For each row of the entity dataframe, only keep the rows where the `event_timestamp_column`
is less than the one provided in the entity dataframe
- If there a TTL for the current feature_view, also keep the rows where the `event_timestamp_column`
is higher the the one provided minus the TTL
- For each row, Join on the entity key and retrieve the `entity_row_unique_id` that has been
computed previously
The output of this CTE will contain all the necessary information and already filtered out most
of the data that is not relevant.
*/
{{ featureview.name }}__subquery AS (
SELECT
{{ featureview.event_timestamp_column }} as event_timestamp,
{{ featureview.created_timestamp_column ~ ' as created_timestamp,' if featureview.created_timestamp_column else '' }}
{{ featureview.entity_selections | join(', ')}}{% if featureview.entity_selections %},{% else %}{% endif %}
{% for feature in featureview.features %}
{{ feature }} as {% if full_feature_names %}{{ featureview.name }}__{{featureview.field_mapping.get(feature, feature)}}{% else %}{{ featureview.field_mapping.get(feature, feature) }}{% endif %}{% if loop.last %}{% else %}, {% endif %}
{% endfor %}
FROM {{ featureview.table_subquery }}
WHERE {{ featureview.event_timestamp_column }} <= '{{ featureview.max_event_timestamp }}'
{% if featureview.ttl == 0 %}{% else %}
AND {{ featureview.event_timestamp_column }} >= '{{ featureview.min_event_timestamp }}'
{% endif %}
),
{{ featureview.name }}__base AS (
SELECT
subquery.*,
entity_dataframe.entity_timestamp,
entity_dataframe.{{featureview.name}}__entity_row_unique_id
FROM {{ featureview.name }}__subquery AS subquery
INNER JOIN {{ featureview.name }}__entity_dataframe AS entity_dataframe
ON TRUE
AND subquery.event_timestamp <= entity_dataframe.entity_timestamp
{% if featureview.ttl == 0 %}{% else %}
AND subquery.event_timestamp >= Timestamp_sub(entity_dataframe.entity_timestamp, interval {{ featureview.ttl }} second)
{% endif %}
{% for entity in featureview.entities %}
AND subquery.{{ entity }} = entity_dataframe.{{ entity }}
{% endfor %}
),
/*
2. If the `created_timestamp_column` has been set, we need to
deduplicate the data first. This is done by calculating the
`MAX(created_at_timestamp)` for each event_timestamp.
We then join the data on the next CTE
*/
{% if featureview.created_timestamp_column %}
{{ featureview.name }}__dedup AS (
SELECT
{{featureview.name}}__entity_row_unique_id,
event_timestamp,
MAX(created_timestamp) as created_timestamp
FROM {{ featureview.name }}__base
GROUP BY {{featureview.name}}__entity_row_unique_id, event_timestamp
),
{% endif %}
/*
3. The data has been filtered during the first CTE "*__base"
Thus we only need to compute the latest timestamp of each feature.
*/
{{ featureview.name }}__latest AS (
SELECT
event_timestamp,
{% if featureview.created_timestamp_column %}created_timestamp,{% endif %}
{{featureview.name}}__entity_row_unique_id
FROM
(
SELECT *,
ROW_NUMBER() OVER(
PARTITION BY {{featureview.name}}__entity_row_unique_id
ORDER BY event_timestamp DESC{% if featureview.created_timestamp_column %},created_timestamp DESC{% endif %}
) AS row_number
FROM {{ featureview.name }}__base
{% if featureview.created_timestamp_column %}
INNER JOIN {{ featureview.name }}__dedup
USING ({{featureview.name}}__entity_row_unique_id, event_timestamp, created_timestamp)
{% endif %}
)
WHERE row_number = 1
),
/*
4. Once we know the latest value of each feature for a given timestamp,
we can join again the data back to the original "base" dataset
*/
{{ featureview.name }}__cleaned AS (
SELECT base.*
FROM {{ featureview.name }}__base as base
INNER JOIN {{ featureview.name }}__latest
USING(
{{featureview.name}}__entity_row_unique_id,
event_timestamp
{% if featureview.created_timestamp_column %}
,created_timestamp
{% endif %}
)
){% if loop.last %}{% else %}, {% endif %}
{% endfor %}
/*
Joins the outputs of multiple time travel joins to a single table.
The entity_dataframe dataset being our source of truth here.
*/
SELECT {{ final_output_feature_names | join(', ')}}
FROM entity_dataframe
{% for featureview in featureviews %}
LEFT JOIN (
SELECT
{{featureview.name}}__entity_row_unique_id
{% for feature in featureview.features %}
,{% if full_feature_names %}{{ featureview.name }}__{{featureview.field_mapping.get(feature, feature)}}{% else %}{{ featureview.field_mapping.get(feature, feature) }}{% endif %}
{% endfor %}
FROM {{ featureview.name }}__cleaned
) USING ({{featureview.name}}__entity_row_unique_id)
{% endfor %}
"""