diff --git a/examples/keras_rs/img/two_stage_rs_with_marketing_interaction/architecture.jpg b/examples/keras_rs/img/two_stage_rs_with_marketing_interaction/architecture.jpg
new file mode 100644
index 0000000000..05e81acfa3
Binary files /dev/null and b/examples/keras_rs/img/two_stage_rs_with_marketing_interaction/architecture.jpg differ
diff --git a/examples/keras_rs/img/two_stage_rs_with_marketing_interaction/two_stage_rs_with_marketing_interaction_13_60.png b/examples/keras_rs/img/two_stage_rs_with_marketing_interaction/two_stage_rs_with_marketing_interaction_13_60.png
new file mode 100644
index 0000000000..af9143ab07
Binary files /dev/null and b/examples/keras_rs/img/two_stage_rs_with_marketing_interaction/two_stage_rs_with_marketing_interaction_13_60.png differ
diff --git a/examples/keras_rs/img/two_stage_rs_with_marketing_interaction/two_stage_rs_with_marketing_interaction_9_90.png b/examples/keras_rs/img/two_stage_rs_with_marketing_interaction/two_stage_rs_with_marketing_interaction_9_90.png
new file mode 100644
index 0000000000..3f99f3b6b5
Binary files /dev/null and b/examples/keras_rs/img/two_stage_rs_with_marketing_interaction/two_stage_rs_with_marketing_interaction_9_90.png differ
diff --git a/examples/keras_rs/ipynb/two_stage_rs_with_marketing_interaction.ipynb b/examples/keras_rs/ipynb/two_stage_rs_with_marketing_interaction.ipynb
new file mode 100644
index 0000000000..2ef7dee801
--- /dev/null
+++ b/examples/keras_rs/ipynb/two_stage_rs_with_marketing_interaction.ipynb
@@ -0,0 +1,694 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "# Two Stage Recommender System with Marketing Interaction\n",
+ "\n",
+ "**Author:** Mansi Mehta
\n",
+ "**Date created:** 26/11/2025
\n",
+ "**Last modified:** 06/12/2025
\n",
+ "**Description:** Recommender System with Ranking and Retrieval model for Marketing interaction."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "# **Introduction**\n",
+ "\n",
+ "This tutorial demonstrates a critical business scenario: a user lands on a website, and a\n",
+ "marketing engine must decide which specific ad to display from an inventory of thousands.\n",
+ "The goal is to maximize the Click-Through Rate (CTR). Showing irrelevant ads wastes\n",
+ "marketing budget and annoys the user. Therefore, we need a system that predicts the\n",
+ "probability of a specific user clicking on a specific ad based on their demographics and\n",
+ "browsing habits.\n",
+ "\n",
+ "**Architecture**\n",
+ "1. **The Retrieval Stage:** Efficiently select an initial set of roughly 10-100\n",
+ "candidates from millions of possibilities. It weeds out items the user is definitely not\n",
+ "interested in.\n",
+ "User Tower: Embeds user features (ID, demographics, behavior) into a vector.\n",
+ "Item Tower: Embeds ad features (Ad ID, Topic) into a vector.\n",
+ "Interaction: The dot product of these two vectors represents similarity.\n",
+ "2. **The Ranking Stage:** It takes the output of the retrieval model and fine-tune the\n",
+ "order to select the single best ad to show.\n",
+ "A Deep Neural Network (MLP).\n",
+ "Interaction: It takes the User Embedding, Ad Embedding, and their similarity score to\n",
+ "predict a precise probability (0% to 100%) that the user will click.\n",
+ "\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "# **Dataset**\n",
+ "We will use the [Ad Click\n",
+ "Prediction](https://www.kaggle.com/datasets/mafrojaakter/ad-click-data) Dataset from\n",
+ "Kaggle\n",
+ "\n",
+ "**Feature Distribution of dataset:**\n",
+ "User Tower describes who is looking and features contains i.e Gender, City, Country, Age,\n",
+ "Daily Internet Usage, Daily Time Spent on Site, and Area Income.\n",
+ "Item Tower describes what is being shown and features contains Ad Topic Line, Ad ID.\n",
+ "\n",
+ "In this tutorial, we are going to build and train a Two-Tower (User Tower and Ad Tower)\n",
+ "model using the Ad Click Prediction dataset from Kaggle.\n",
+ "We're going to:\n",
+ "1. **Data Pipeline:** Get our data and preprocess it for both Retrieval (implicit\n",
+ "feedback) and Ranking (explicit labels).\n",
+ "2. **Retrieval:** Implement and train a Two-Tower model to generate candidates.\n",
+ "3. **Ranking:** Implement and train a Neural Ranking model to predict click probabilities.\n",
+ "4. **Inference:** Run an end-to-end test (Retrieval --> Ranking) to generate\n",
+ "recommendations for a specific user."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -q keras-rs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
+ "import keras\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import tensorflow as tf\n",
+ "import pandas as pd\n",
+ "import keras_rs\n",
+ "\n",
+ "from keras import layers\n",
+ "from concurrent.futures import ThreadPoolExecutor\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from sklearn.preprocessing import MinMaxScaler\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "# **Preparing Dataset**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -q kaggle\n",
+ "!kaggle datasets download -d mafrojaakter/ad-click-data --unzip -p ./ad_click_dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "data_path = \"./ad_click_dataset/Ad Click Data.csv\"\n",
+ "if not os.path.exists(data_path):\n",
+ " # Fallback for filenames with spaces or different casing\n",
+ " data_path = \"./ad_click_dataset/Ad Click Data.csv\"\n",
+ "\n",
+ "ads_df = pd.read_csv(data_path)\n",
+ "# Clean column names\n",
+ "ads_df.columns = ads_df.columns.str.strip()\n",
+ "# Rename the column name\n",
+ "ads_df = ads_df.rename(\n",
+ " columns={\n",
+ " \"Male\": \"gender\",\n",
+ " \"Ad Topic Line\": \"ad_topic\",\n",
+ " \"City\": \"city\",\n",
+ " \"Country\": \"country\",\n",
+ " \"Daily Time Spent on Site\": \"time_on_site\",\n",
+ " \"Daily Internet Usage\": \"internet_usage\",\n",
+ " \"Area Income\": \"area_income\",\n",
+ " }\n",
+ ")\n",
+ "# Add user_id and add_id column\n",
+ "ads_df[\"user_id\"] = \"user_\" + ads_df.index.astype(str)\n",
+ "ads_df[\"ad_id\"] = \"ad_\" + ads_df[\"ad_topic\"].astype(\"category\").cat.codes.astype(str)\n",
+ "# Remove nulls and normalize\n",
+ "ads_df = ads_df.dropna()\n",
+ "# normalize\n",
+ "numeric_cols = [\"time_on_site\", \"internet_usage\", \"area_income\", \"Age\"]\n",
+ "scaler = MinMaxScaler()\n",
+ "ads_df[numeric_cols] = scaler.fit_transform(ads_df[numeric_cols])\n",
+ "\n",
+ "# Split the train and test datasets\n",
+ "x_train, x_test = train_test_split(ads_df, test_size=0.2, random_state=42)\n",
+ "\n",
+ "\n",
+ "def dict_to_tensor_features(df_features, continuous_features):\n",
+ " tensor_dict = {}\n",
+ " for k, v in df_features.items():\n",
+ " if k in continuous_features:\n",
+ " tensor_dict[k] = tf.expand_dims(tf.constant(v, dtype=\"float32\"), axis=-1)\n",
+ " else:\n",
+ " v_str = np.array(v).astype(str).tolist()\n",
+ " tensor_dict[k] = tf.expand_dims(tf.constant(v_str, dtype=\"string\"), axis=-1)\n",
+ " return tensor_dict\n",
+ "\n",
+ "\n",
+ "def create_retrieval_dataset(\n",
+ " data_df,\n",
+ " all_ads_features,\n",
+ " all_ad_ids,\n",
+ " user_features_list,\n",
+ " ad_features_list,\n",
+ " continuous_features_list,\n",
+ "):\n",
+ "\n",
+ " # Filter for Positive Interactions (Clicks)\n",
+ " positive_interactions = data_df[data_df[\"Clicked on Ad\"] == 1].copy()\n",
+ "\n",
+ " if positive_interactions.empty:\n",
+ " return None\n",
+ "\n",
+ " def sample_negative(positive_ad_id):\n",
+ " all_ad_ids_filtered = [aid for aid in all_ad_ids if aid != positive_ad_id]\n",
+ " if not all_ad_ids_filtered:\n",
+ " return positive_ad_id\n",
+ " neg_ad_id = np.random.choice(all_ad_ids_filtered)\n",
+ " return neg_ad_id\n",
+ "\n",
+ " def create_triplets_row(pos_row):\n",
+ " pos_ad_id = pos_row.ad_id\n",
+ " neg_ad_id = sample_negative(pos_ad_id)\n",
+ "\n",
+ " neg_ad_row = all_ads_features[all_ads_features[\"ad_id\"] == neg_ad_id].iloc[0]\n",
+ " user_features_dict = {\n",
+ " name: getattr(pos_row, name) for name in user_features_list\n",
+ " }\n",
+ " pos_ad_features_dict = {\n",
+ " name: getattr(pos_row, name) for name in ad_features_list\n",
+ " }\n",
+ " neg_ad_features_dict = {name: neg_ad_row[name] for name in ad_features_list}\n",
+ "\n",
+ " return {\n",
+ " \"user\": user_features_dict,\n",
+ " \"positive_ad\": pos_ad_features_dict,\n",
+ " \"negative_ad\": neg_ad_features_dict,\n",
+ " }\n",
+ "\n",
+ " with ThreadPoolExecutor(max_workers=os.cpu_count() or 8) as executor:\n",
+ " triplets = list(\n",
+ " executor.map(\n",
+ " create_triplets_row, positive_interactions.itertuples(index=False)\n",
+ " )\n",
+ " )\n",
+ "\n",
+ " triplets_df = pd.DataFrame(triplets)\n",
+ " user_df = triplets_df[\"user\"].apply(pd.Series)\n",
+ " pos_ad_df = triplets_df[\"positive_ad\"].apply(pd.Series)\n",
+ " neg_ad_df = triplets_df[\"negative_ad\"].apply(pd.Series)\n",
+ "\n",
+ " user_features_tensor = dict_to_tensor_features(\n",
+ " user_df.to_dict(\"list\"), continuous_features_list\n",
+ " )\n",
+ " pos_ad_features_tensor = dict_to_tensor_features(\n",
+ " pos_ad_df.to_dict(\"list\"), continuous_features_list\n",
+ " )\n",
+ " neg_ad_features_tensor = dict_to_tensor_features(\n",
+ " neg_ad_df.to_dict(\"list\"), continuous_features_list\n",
+ " )\n",
+ "\n",
+ " features = {\n",
+ " \"user\": user_features_tensor,\n",
+ " \"positive_ad\": pos_ad_features_tensor,\n",
+ " \"negative_ad\": neg_ad_features_tensor,\n",
+ " }\n",
+ " y_true = tf.ones((triplets_df.shape[0], 1), dtype=tf.float32)\n",
+ " dataset = tf.data.Dataset.from_tensor_slices((features, y_true))\n",
+ " buffer_size = len(triplets_df)\n",
+ " dataset = (\n",
+ " dataset.shuffle(buffer_size=buffer_size)\n",
+ " .batch(64)\n",
+ " .cache()\n",
+ " .prefetch(tf.data.AUTOTUNE)\n",
+ " )\n",
+ " return dataset\n",
+ "\n",
+ "\n",
+ "user_clicked_ads = (\n",
+ " x_train[x_train[\"Clicked on Ad\"] == 1]\n",
+ " .groupby(\"user_id\")[\"ad_id\"]\n",
+ " .apply(set)\n",
+ " .to_dict()\n",
+ ")\n",
+ "\n",
+ "for u in x_train[\"user_id\"].unique():\n",
+ " if u not in user_clicked_ads:\n",
+ " user_clicked_ads[u] = set()\n",
+ "\n",
+ "AD_FEATURES = [\"ad_id\", \"ad_topic\"]\n",
+ "USER_FEATURES = [\n",
+ " \"user_id\",\n",
+ " \"gender\",\n",
+ " \"city\",\n",
+ " \"country\",\n",
+ " \"time_on_site\",\n",
+ " \"internet_usage\",\n",
+ " \"area_income\",\n",
+ " \"Age\",\n",
+ "]\n",
+ "continuous_features = [\"time_on_site\", \"internet_usage\", \"area_income\", \"Age\"]\n",
+ "\n",
+ "all_ads_features = x_train[AD_FEATURES].drop_duplicates().reset_index(drop=True)\n",
+ "all_ad_ids = all_ads_features[\"ad_id\"].tolist()\n",
+ "\n",
+ "retrieval_train_dataset = create_retrieval_dataset(\n",
+ " data_df=x_train,\n",
+ " all_ads_features=all_ads_features,\n",
+ " all_ad_ids=all_ad_ids,\n",
+ " user_features_list=USER_FEATURES,\n",
+ " ad_features_list=AD_FEATURES,\n",
+ " continuous_features_list=continuous_features,\n",
+ ")\n",
+ "\n",
+ "retrieval_test_dataset = create_retrieval_dataset(\n",
+ " data_df=x_test,\n",
+ " all_ads_features=all_ads_features,\n",
+ " all_ad_ids=all_ad_ids,\n",
+ " user_features_list=USER_FEATURES,\n",
+ " ad_features_list=AD_FEATURES,\n",
+ " continuous_features_list=continuous_features,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "# **Implement the Retrieval Model**\n",
+ "For the Retrieval stage, we will build a Two-Tower Model.\n",
+ "\n",
+ "**The Architecture Components:**\n",
+ "\n",
+ "1. User Tower:User features (User ID, demographics, behavior metrics like time_on_site).\n",
+ "It encodes these mixed features into a fixed-size vector representation called the User\n",
+ "Embedding.\n",
+ "2. Item (Ad) Tower:Ad features (Ad ID, Ad Topic Line).It encodes these features into a\n",
+ "fixed-size vector representation called the Item Embedding.\n",
+ "3. Interaction (Similarity):We calculate the Dot Product between the User Embedding and\n",
+ "the Item Embedding."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "keras.utils.set_random_seed(42)\n",
+ "\n",
+ "vocab_map = {\n",
+ " \"user_id\": x_train[\"user_id\"].unique(),\n",
+ " \"gender\": x_train[\"gender\"].astype(str).unique(),\n",
+ " \"city\": x_train[\"city\"].unique(),\n",
+ " \"country\": x_train[\"country\"].unique(),\n",
+ " \"ad_id\": x_train[\"ad_id\"].unique(),\n",
+ " \"ad_topic\": x_train[\"ad_topic\"].unique(),\n",
+ "}\n",
+ "cont_feats = [\"time_on_site\", \"internet_usage\", \"area_income\", \"Age\"]\n",
+ "\n",
+ "normalizers = {}\n",
+ "for f in cont_feats:\n",
+ " norm = layers.Normalization(axis=None)\n",
+ " norm.adapt(x_train[f].values.astype(\"float32\"))\n",
+ " normalizers[f] = norm\n",
+ "\n",
+ "\n",
+ "def build_tower(feature_names, continuous_names=None, embed_dim=64, name=\"tower\"):\n",
+ " inputs, embeddings = {}, []\n",
+ "\n",
+ " for feat in feature_names:\n",
+ " if feat in vocab_map:\n",
+ " inp = keras.Input(shape=(1,), dtype=tf.string, name=feat)\n",
+ " inputs[feat] = inp\n",
+ " vocab = list(vocab_map[feat])\n",
+ " x = layers.StringLookup(vocabulary=vocab)(inp)\n",
+ " x = layers.Embedding(\n",
+ " len(vocab) + 1, embed_dim, embeddings_regularizer=\"l2\"\n",
+ " )(x)\n",
+ " embeddings.append(layers.Flatten()(x))\n",
+ "\n",
+ " if continuous_names:\n",
+ " for feat in continuous_names:\n",
+ " inp = keras.Input(shape=(1,), dtype=tf.float32, name=feat)\n",
+ " inputs[feat] = inp\n",
+ " embeddings.append(normalizers[feat](inp))\n",
+ "\n",
+ " x = layers.Concatenate()(embeddings)\n",
+ " x = layers.Dense(128, activation=\"relu\")(x)\n",
+ " x = layers.Dropout(0.2)(x)\n",
+ " x = layers.Dense(64, activation=\"relu\")(x)\n",
+ " output = layers.Dense(embed_dim)(layers.Dropout(0.2)(x))\n",
+ "\n",
+ " return keras.Model(inputs=inputs, outputs=output, name=name)\n",
+ "\n",
+ "\n",
+ "user_tower = build_tower(\n",
+ " [\"user_id\", \"gender\", \"city\", \"country\"], cont_feats, name=\"user_tower\"\n",
+ ")\n",
+ "ad_tower = build_tower([\"ad_id\", \"ad_topic\"], name=\"ad_tower\")\n",
+ "\n",
+ "\n",
+ "def pairwise_logistic_loss(y_true, y_pred):\n",
+ " return -tf.math.log(tf.nn.sigmoid(y_pred) + 1e-10)\n",
+ "\n",
+ "\n",
+ "class RetrievalModel(keras.Model):\n",
+ " def __init__(self, user_tower_instance, ad_tower_instance, **kwargs):\n",
+ " super().__init__(**kwargs)\n",
+ " self.user_tower = user_tower_instance\n",
+ " self.ad_tower = ad_tower\n",
+ " self.ln_user = layers.LayerNormalization()\n",
+ " self.ln_ad = layers.LayerNormalization()\n",
+ "\n",
+ " def call(self, inputs):\n",
+ " u_emb = self.ln_user(self.user_tower(inputs[\"user\"]))\n",
+ " pos_emb = self.ln_ad(self.ad_tower(inputs[\"positive_ad\"]))\n",
+ " neg_emb = self.ln_ad(self.ad_tower(inputs[\"negative_ad\"]))\n",
+ " pos_score = keras.ops.sum(u_emb * pos_emb, axis=1, keepdims=True)\n",
+ " neg_score = keras.ops.sum(u_emb * neg_emb, axis=1, keepdims=True)\n",
+ " return pos_score - neg_score\n",
+ "\n",
+ " def get_embeddings(self, inputs):\n",
+ " u_emb = self.ln_user(self.user_tower(inputs[\"user\"]))\n",
+ " ad_emb = self.ln_ad(self.ad_tower(inputs[\"positive_ad\"]))\n",
+ " dot_interaction = keras.ops.sum(u_emb * ad_emb, axis=1, keepdims=True)\n",
+ " return u_emb, ad_emb, dot_interaction\n",
+ "\n",
+ "\n",
+ "retrieval_model = RetrievalModel(user_tower, ad_tower)\n",
+ "retrieval_model.compile(\n",
+ " optimizer=keras.optimizers.Adam(learning_rate=1e-3), loss=pairwise_logistic_loss\n",
+ ")\n",
+ "history = retrieval_model.fit(retrieval_train_dataset, epochs=30)\n",
+ "\n",
+ "pd.DataFrame(history.history).plot(\n",
+ " subplots=True, layout=(1, 3), figsize=(12, 4), title=\"Retrieval Model Metrics\"\n",
+ ")\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "# **Predictions of Retrieval Model**\n",
+ "We can implement inference pipeline for retrieval Model using three steps:\n",
+ "1. Indexing: We can run the Item Tower once for all available ads to generate their\n",
+ "embeddings.\n",
+ "2. Query Encoding: When a user arrives, we pass their features through the User Tower to\n",
+ "generate a User Embedding.\n",
+ "3. Nearest Neighbor Search: We search the index to find the Ad Embeddings closest to the\n",
+ "User Embedding (highest dot product).\n",
+ "\n",
+ "Keras-RS [BruteForceRetrieval\n",
+ "layer](https://keras.io/keras_rs/api/retrieval_layers/brute_force_retrieval/) calculates\n",
+ "dot product between the user and every single item in the index to find exact top-K\n",
+ "matches"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "USER_CATEGORICAL = [\"user_id\", \"gender\", \"city\", \"country\"]\n",
+ "CONTINUOUS_FEATURES = [\"time_on_site\", \"internet_usage\", \"area_income\", \"Age\"]\n",
+ "USER_FEATURES = USER_CATEGORICAL + CONTINUOUS_FEATURES\n",
+ "\n",
+ "\n",
+ "class BruteForceRetrievalWrapper:\n",
+ " def __init__(self, model, ads_df, ad_features, user_features, k=10):\n",
+ " self.model, self.k = model, k\n",
+ " self.user_features = user_features\n",
+ " unique_ads = ads_df[ad_features].drop_duplicates(\"ad_id\").reset_index(drop=True)\n",
+ " self.ids = unique_ads[\"ad_id\"].values\n",
+ " self.topic_map = dict(zip(unique_ads[\"ad_id\"], unique_ads[\"ad_topic\"]))\n",
+ " ad_inputs = {\n",
+ " \"ad_id\": tf.constant(self.ids.astype(str)),\n",
+ " \"ad_topic\": tf.constant(unique_ads[\"ad_topic\"].astype(str).values),\n",
+ " }\n",
+ " self.candidate_embs = model.ln_ad(model.ad_tower(ad_inputs))\n",
+ "\n",
+ " def query_batch(self, user_df):\n",
+ " inputs = {\n",
+ " k: tf.constant(\n",
+ " user_df[k].values.astype(float if k in CONTINUOUS_FEATURES else str)\n",
+ " )\n",
+ " for k in self.user_features\n",
+ " if k in user_df.columns\n",
+ " }\n",
+ " u_emb = self.model.ln_user(self.model.user_tower(inputs))\n",
+ " scores = tf.linalg.matmul(u_emb, self.candidate_embs, transpose_b=True)\n",
+ " top_scores, top_indices = tf.math.top_k(scores, k=self.k)\n",
+ " return top_scores.numpy(), top_indices.numpy()\n",
+ "\n",
+ " def decode_results(self, scores, indices):\n",
+ " results = []\n",
+ " for row_scores, row_indices in zip(scores, indices):\n",
+ " retrieved_ids = self.ids[row_indices]\n",
+ " results.append(\n",
+ " [\n",
+ " {\"ad_id\": aid, \"ad_topic\": self.topic_map[aid], \"score\": float(s)}\n",
+ " for aid, s in zip(retrieved_ids, row_scores)\n",
+ " ]\n",
+ " )\n",
+ " return results\n",
+ "\n",
+ "\n",
+ "retrieval_engine = BruteForceRetrievalWrapper(\n",
+ " model=retrieval_model,\n",
+ " ads_df=ads_df,\n",
+ " ad_features=[\"ad_id\", \"ad_topic\"],\n",
+ " user_features=USER_FEATURES,\n",
+ " k=10,\n",
+ ")\n",
+ "sample_user = pd.DataFrame([x_test.iloc[0]])\n",
+ "scores, indices = retrieval_engine.query_batch(sample_user)\n",
+ "top_ads = retrieval_engine.decode_results(scores, indices)[0]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "# **Implementation of Ranking Model**\n",
+ "Retrieval model only calculates a simple similarity score (Dot Product). It doesn't\n",
+ "account for complex feature interactions.\n",
+ "So we need to build ranking model after words retrieval model.\n",
+ "\n",
+ "**Architecture**\n",
+ "1. **Feature Extraction:** We reuse the trained User Tower and Ad Tower from the\n",
+ "Retrieval stage. We freeze these towers (trainable = False) so their weights don't\n",
+ "change.\n",
+ "2. **Interaction:** Instead of just a dot product, we concatenate three inputs- The User\n",
+ "EmbeddingThe Ad EmbeddingThe Dot Product (Similarity)\n",
+ "3. **Scorer(MLP):** These concatenated inputs are fed into a Multi-Layer Perceptron—a\n",
+ "stack of Dense layers. This network learns the non-linear relationships between the user\n",
+ "and the ad.\n",
+ "4. **Output:** The final layer uses a Sigmoid activation to output a single probability\n",
+ "between 0.0 and 1.0 (Likelihood of a Click)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "retrieval_model.trainable = False\n",
+ "\n",
+ "\n",
+ "def create_ranking_ds(df):\n",
+ " inputs = {\n",
+ " \"user\": dict_to_tensor_features(df[USER_FEATURES], continuous_features),\n",
+ " \"positive_ad\": dict_to_tensor_features(df[AD_FEATURES], continuous_features),\n",
+ " }\n",
+ " return (\n",
+ " tf.data.Dataset.from_tensor_slices(\n",
+ " (inputs, df[\"Clicked on Ad\"].values.astype(\"float32\"))\n",
+ " )\n",
+ " .shuffle(10000)\n",
+ " .batch(256)\n",
+ " .prefetch(tf.data.AUTOTUNE)\n",
+ " )\n",
+ "\n",
+ "\n",
+ "ranking_train_dataset = create_ranking_ds(x_train)\n",
+ "ranking_test_dataset = create_ranking_ds(x_test)\n",
+ "\n",
+ "\n",
+ "class RankingModel(keras.Model):\n",
+ " def __init__(self, retrieval_model, **kwargs):\n",
+ " super().__init__(**kwargs)\n",
+ " self.retrieval = retrieval_model\n",
+ " self.mlp = keras.Sequential(\n",
+ " [\n",
+ " layers.Dense(256, activation=\"relu\"),\n",
+ " layers.Dropout(0.2),\n",
+ " layers.Dense(128, activation=\"relu\"),\n",
+ " layers.Dropout(0.2),\n",
+ " layers.Dense(64, activation=\"relu\"),\n",
+ " layers.Dense(1, activation=\"sigmoid\"),\n",
+ " ]\n",
+ " )\n",
+ "\n",
+ " def call(self, inputs):\n",
+ " u_emb, ad_emb, dot = self.retrieval.get_embeddings(inputs)\n",
+ " return self.mlp(keras.ops.concatenate([u_emb, ad_emb, dot], axis=-1))\n",
+ "\n",
+ "\n",
+ "ranking_model = RankingModel(retrieval_model)\n",
+ "ranking_model.compile(\n",
+ " optimizer=keras.optimizers.Adam(1e-4),\n",
+ " loss=\"binary_crossentropy\",\n",
+ " metrics=[\"AUC\", \"accuracy\"],\n",
+ ")\n",
+ "history1 = ranking_model.fit(ranking_train_dataset, epochs=20)\n",
+ "\n",
+ "pd.DataFrame(history1.history).plot(\n",
+ " subplots=True, layout=(1, 3), figsize=(12, 4), title=\"Ranking Model Metrics\"\n",
+ ")\n",
+ "plt.show()\n",
+ "\n",
+ "ranking_model.evaluate(ranking_test_dataset)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "# **Predictions of Ranking Model**\n",
+ "The retrieval model gave us a list of ads that are generally relevant (high dot product\n",
+ "similarity). The ranking model will now calculate the specific probability (0% to 100%)\n",
+ "that the user will click each of those ads.\n",
+ "\n",
+ "The Ranking model expects pairs of (User, Ad). Since we are scoring 10 ads for 1 user, we\n",
+ "cannot just pass the user features once.We effectively take user's features 10 times to\n",
+ "create a batch."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "def rerank_ads_for_user(user_row, retrieved_ads, ranking_model):\n",
+ " ads_df = pd.DataFrame(retrieved_ads)\n",
+ " num_ads = len(ads_df)\n",
+ " user_inputs = {\n",
+ " k: tf.fill(\n",
+ " (num_ads, 1),\n",
+ " str(user_row[k]) if k not in continuous_features else float(user_row[k]),\n",
+ " )\n",
+ " for k in USER_FEATURES\n",
+ " }\n",
+ " ad_inputs = {\n",
+ " k: tf.reshape(tf.constant(ads_df[k].astype(str).values), (-1, 1))\n",
+ " for k in AD_FEATURES\n",
+ " }\n",
+ " scores = (\n",
+ " ranking_model({\"user\": user_inputs, \"positive_ad\": ad_inputs}).numpy().flatten()\n",
+ " )\n",
+ " ads_df[\"ranking_score\"] = scores\n",
+ " return ads_df.sort_values(\"ranking_score\", ascending=False).to_dict(\"records\")\n",
+ "\n",
+ "\n",
+ "sample_user = x_test.iloc[0]\n",
+ "scores, indices = retrieval_engine.query_batch(pd.DataFrame([sample_user]))\n",
+ "top_ads = retrieval_engine.decode_results(scores, indices)[0]\n",
+ "final_ranked_ads = rerank_ads_for_user(sample_user, top_ads, ranking_model)\n",
+ "print(f\"User: {sample_user['user_id']}\")\n",
+ "print(f\"{'Ad ID':<10} | {'Topic':<30} | {'Retrieval Score':<11} | {'Rank Probability'}\")\n",
+ "for item in final_ranked_ads:\n",
+ " print(\n",
+ " f\"{item['ad_id']:<10} | {item['ad_topic'][:28]:<30} | {item['score']:.4f} | {item['ranking_score']*100:.2f}%\"\n",
+ " )"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [],
+ "name": "two_stage_rs_with_marketing_interaction",
+ "private_outputs": false,
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/examples/keras_rs/md/two_stage_rs_with_marketing_interaction.md b/examples/keras_rs/md/two_stage_rs_with_marketing_interaction.md
new file mode 100644
index 0000000000..5053f8d9f4
--- /dev/null
+++ b/examples/keras_rs/md/two_stage_rs_with_marketing_interaction.md
@@ -0,0 +1,818 @@
+# Two Stage Recommender System with Marketing Interaction
+
+**Author:** Mansi Mehta
+**Date created:** 26/11/2025
+**Last modified:** 06/12/2025
+**Description:** Recommender System with Ranking and Retrieval model for Marketing interaction.
+
+
+ [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_rs/ipynb/two_stage_rs_with_marketing_interaction.ipynb) •
[**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_rs/two_stage_rs_with_marketing_interaction.py)
+
+
+
+# **Introduction**
+
+This tutorial demonstrates a critical business scenario: a user lands on a website, and a
+marketing engine must decide which specific ad to display from an inventory of thousands.
+The goal is to maximize the Click-Through Rate (CTR). Showing irrelevant ads wastes
+marketing budget and annoys the user. Therefore, we need a system that predicts the
+probability of a specific user clicking on a specific ad based on their demographics and
+browsing habits.
+
+**Architecture**
+1. **The Retrieval Stage:** Efficiently select an initial set of roughly 10-100
+candidates from millions of possibilities. It weeds out items the user is definitely not
+interested in.
+User Tower: Embeds user features (ID, demographics, behavior) into a vector.
+Item Tower: Embeds ad features (Ad ID, Topic) into a vector.
+Interaction: The dot product of these two vectors represents similarity.
+2. **The Ranking Stage:** It takes the output of the retrieval model and fine-tune the
+order to select the single best ad to show.
+A Deep Neural Network (MLP).
+Interaction: It takes the User Embedding, Ad Embedding, and their similarity score to
+predict a precise probability (0% to 100%) that the user will click.
+
+
+
+
+# **Dataset**
+We will use the [Ad Click
+Prediction](https://www.kaggle.com/datasets/mafrojaakter/ad-click-data) Dataset from
+Kaggle
+
+**Feature Distribution of dataset:**
+User Tower describes who is looking and features contains i.e Gender, City, Country, Age,
+Daily Internet Usage, Daily Time Spent on Site, and Area Income.
+Item Tower describes what is being shown and features contains Ad Topic Line, Ad ID.
+
+In this tutorial, we are going to build and train a Two-Tower (User Tower and Ad Tower)
+model using the Ad Click Prediction dataset from Kaggle.
+We're going to:
+1. **Data Pipeline:** Get our data and preprocess it for both Retrieval (implicit
+feedback) and Ranking (explicit labels).
+2. **Retrieval:** Implement and train a Two-Tower model to generate candidates.
+3. **Ranking:** Implement and train a Neural Ranking model to predict click probabilities.
+4. **Inference:** Run an end-to-end test (Retrieval --> Ranking) to generate
+recommendations for a specific user.
+
+
+```python
+!pip install -q keras-rs
+```
+
+
+```python
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+import keras
+import matplotlib.pyplot as plt
+import numpy as np
+import tensorflow as tf
+import pandas as pd
+import keras_rs
+
+from keras import layers
+from concurrent.futures import ThreadPoolExecutor
+from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import MinMaxScaler
+
+```
+
+# **Preparing Dataset**
+
+
+```python
+!pip install -q kaggle
+!kaggle datasets download -d mafrojaakter/ad-click-data --unzip -p ./ad_click_dataset
+```
+
+
+
+
+```python
+data_path = "./ad_click_dataset/Ad Click Data.csv"
+if not os.path.exists(data_path):
+ # Fallback for filenames with spaces or different casing
+ data_path = "./ad_click_dataset/Ad Click Data.csv"
+
+ads_df = pd.read_csv(data_path)
+# Clean column names
+ads_df.columns = ads_df.columns.str.strip()
+# Rename the column name
+ads_df = ads_df.rename(
+ columns={
+ "Male": "gender",
+ "Ad Topic Line": "ad_topic",
+ "City": "city",
+ "Country": "country",
+ "Daily Time Spent on Site": "time_on_site",
+ "Daily Internet Usage": "internet_usage",
+ "Area Income": "area_income",
+ }
+)
+# Add user_id and add_id column
+ads_df["user_id"] = "user_" + ads_df.index.astype(str)
+ads_df["ad_id"] = "ad_" + ads_df["ad_topic"].astype("category").cat.codes.astype(str)
+# Remove nulls and normalize
+ads_df = ads_df.dropna()
+# normalize
+numeric_cols = ["time_on_site", "internet_usage", "area_income", "Age"]
+scaler = MinMaxScaler()
+ads_df[numeric_cols] = scaler.fit_transform(ads_df[numeric_cols])
+
+# Split the train and test datasets
+x_train, x_test = train_test_split(ads_df, test_size=0.2, random_state=42)
+
+
+def dict_to_tensor_features(df_features, continuous_features):
+ tensor_dict = {}
+ for k, v in df_features.items():
+ if k in continuous_features:
+ tensor_dict[k] = tf.expand_dims(tf.constant(v, dtype="float32"), axis=-1)
+ else:
+ v_str = np.array(v).astype(str).tolist()
+ tensor_dict[k] = tf.expand_dims(tf.constant(v_str, dtype="string"), axis=-1)
+ return tensor_dict
+
+
+def create_retrieval_dataset(
+ data_df,
+ all_ads_features,
+ all_ad_ids,
+ user_features_list,
+ ad_features_list,
+ continuous_features_list,
+):
+
+ # Filter for Positive Interactions (Clicks)
+ positive_interactions = data_df[data_df["Clicked on Ad"] == 1].copy()
+
+ if positive_interactions.empty:
+ return None
+
+ def sample_negative(positive_ad_id):
+ all_ad_ids_filtered = [aid for aid in all_ad_ids if aid != positive_ad_id]
+ if not all_ad_ids_filtered:
+ return positive_ad_id
+ neg_ad_id = np.random.choice(all_ad_ids_filtered)
+ return neg_ad_id
+
+ def create_triplets_row(pos_row):
+ pos_ad_id = pos_row.ad_id
+ neg_ad_id = sample_negative(pos_ad_id)
+
+ neg_ad_row = all_ads_features[all_ads_features["ad_id"] == neg_ad_id].iloc[0]
+ user_features_dict = {
+ name: getattr(pos_row, name) for name in user_features_list
+ }
+ pos_ad_features_dict = {
+ name: getattr(pos_row, name) for name in ad_features_list
+ }
+ neg_ad_features_dict = {name: neg_ad_row[name] for name in ad_features_list}
+
+ return {
+ "user": user_features_dict,
+ "positive_ad": pos_ad_features_dict,
+ "negative_ad": neg_ad_features_dict,
+ }
+
+ with ThreadPoolExecutor(max_workers=os.cpu_count() or 8) as executor:
+ triplets = list(
+ executor.map(
+ create_triplets_row, positive_interactions.itertuples(index=False)
+ )
+ )
+
+ triplets_df = pd.DataFrame(triplets)
+ user_df = triplets_df["user"].apply(pd.Series)
+ pos_ad_df = triplets_df["positive_ad"].apply(pd.Series)
+ neg_ad_df = triplets_df["negative_ad"].apply(pd.Series)
+
+ user_features_tensor = dict_to_tensor_features(
+ user_df.to_dict("list"), continuous_features_list
+ )
+ pos_ad_features_tensor = dict_to_tensor_features(
+ pos_ad_df.to_dict("list"), continuous_features_list
+ )
+ neg_ad_features_tensor = dict_to_tensor_features(
+ neg_ad_df.to_dict("list"), continuous_features_list
+ )
+
+ features = {
+ "user": user_features_tensor,
+ "positive_ad": pos_ad_features_tensor,
+ "negative_ad": neg_ad_features_tensor,
+ }
+ y_true = tf.ones((triplets_df.shape[0], 1), dtype=tf.float32)
+ dataset = tf.data.Dataset.from_tensor_slices((features, y_true))
+ buffer_size = len(triplets_df)
+ dataset = (
+ dataset.shuffle(buffer_size=buffer_size)
+ .batch(64)
+ .cache()
+ .prefetch(tf.data.AUTOTUNE)
+ )
+ return dataset
+
+
+user_clicked_ads = (
+ x_train[x_train["Clicked on Ad"] == 1]
+ .groupby("user_id")["ad_id"]
+ .apply(set)
+ .to_dict()
+)
+
+for u in x_train["user_id"].unique():
+ if u not in user_clicked_ads:
+ user_clicked_ads[u] = set()
+
+AD_FEATURES = ["ad_id", "ad_topic"]
+USER_FEATURES = [
+ "user_id",
+ "gender",
+ "city",
+ "country",
+ "time_on_site",
+ "internet_usage",
+ "area_income",
+ "Age",
+]
+continuous_features = ["time_on_site", "internet_usage", "area_income", "Age"]
+
+all_ads_features = x_train[AD_FEATURES].drop_duplicates().reset_index(drop=True)
+all_ad_ids = all_ads_features["ad_id"].tolist()
+
+retrieval_train_dataset = create_retrieval_dataset(
+ data_df=x_train,
+ all_ads_features=all_ads_features,
+ all_ad_ids=all_ad_ids,
+ user_features_list=USER_FEATURES,
+ ad_features_list=AD_FEATURES,
+ continuous_features_list=continuous_features,
+)
+
+retrieval_test_dataset = create_retrieval_dataset(
+ data_df=x_test,
+ all_ads_features=all_ads_features,
+ all_ad_ids=all_ad_ids,
+ user_features_list=USER_FEATURES,
+ ad_features_list=AD_FEATURES,
+ continuous_features_list=continuous_features,
+)
+```
+