Towards Open-World Feature Extrapolation:
An Inductive Graph Learning Approach
Abstract
We target the open-world feature extrapolation problem where the feature space of input data goes through expansion and a model trained on partially observed features needs to handle new features in test data without further retraining. The problem is of much significance for dealing with features incrementally collected from different fields. To this end, we propose a new learning paradigm with graph representation and learning. Our framework contains two modules: 1) a backbone network (e.g., a feedforward neural net) as a lower model takes features as input and outputs predicted labels; 2) a graph neural network as an upper model learns to extrapolate embeddings for new features via message passing over a feature-data graph built from observed data. Based on our framework, we design two training strategies, a self-supervised approach and an inductive learning approach, to endow the model with extrapolation ability and alleviate feature-level over-fitting. We also provide theoretical analysis of the generalization error on test data with new features, which dissects the impact of training features and algorithms on generalization performance. Our experiments over several classification datasets and large-scale advertisement click prediction datasets demonstrate that our model can produce effective embeddings for unseen features and significantly outperforms baseline methods that adopt KNN and local aggregation. The implementation codes are publicly available at https://github.com/qitianwu/FATE.
1 Introduction
Learning a mapping from an observation (a vector of attribute features) to a label is a fundamental and pervasive problem in the ML community, with extensive applications spanning classification/regression on tabular data, advertisement click prediction feature-ctr ; lr ; wd ; fm ; fwl , item recommendation feature-rec ; feature-rec2 ; youtube ; pinsage , question answering qa2 ; qa1 , and AI more broadly. Existing approaches focus on a fixed input feature space shared by training and test data. Nevertheless, practical ML systems interact with a dynamic open world where features are incrementally collected. For example, in recommender/advertisement systems, new user profile features unseen during training often appear for the current prediction task. Also, with the advances of multi-modal recognition mm-4 and federated learning mm-2 , a model trained with partial features is often required to incorporate new features from other fields for decision-making on a target task.
A challenge stems from the fact that off-the-shelf neural network models cannot deal with new features without re-training on new data. As shown in Fig. 1(a), a neural network builds a mapping from input features to a hidden representation through a weight matrix in the first layer. Given new features as input, the network needs to be augmented with new weight parameters and, without re-training, cannot map the new features to a desirable position in latent space.
However, model re-training is tricky and brings up several issues. First, re-training a model with both previous and new features would be highly time-consuming and cannot meet the latency requirements of online systems. Alternatively, one can re-train a trained model only on new features, which risks over-fitting the new data or forgetting the previous data.
Fortunately, different from machines, humans are often equipped with the ability to extrapolate to unseen features and distill the knowledge in new information for solving a target task without any re-training on new data. The inherent gap between existing ML approaches and human intelligence raises a research question: Can we design an ML model that is trained on one set of features and can generalize to incorporate new, unseen features for the same task without further training?

Finding desirable solutions to this non-trivial question is challenging. Let us resort to how humans think and act in physical-world scenarios. Imagine that we are well trained on predicting a person’s income from one’s age and occupation in historical records. Now we are asked to estimate the income of a new person given his/her age, occ. and education (a new feature). We can modularize the process by which our brain distills and leverages the knowledge in new observations into four steps, as shown in Fig. 1(b). 1) Perception: the observations are recognized by our senses; 2) Abstraction: the perceived information is aligned with concepts in our cognition; 3) Reasoning: we search similar concepts and observations in our memory to understand and absorb the new knowledge; 4) Decision: with the new understanding and abstraction, we make final decisions for the prediction.
The above thinking process inspires the methodology in our paper, where we propose a new learning paradigm for open-world feature extrapolation. Our proposed framework contains two modules: a backbone network, which can be a feedforward neural network, and a graph neural network. 1) The backbone network first maps input features to embeddings, which can be seen as perception from observation. 2) We then treat the observed data matrix (each row represents a feature vector for one instance) as a feature-data bipartite graph, which explicitly defines proximity and locality structures among features and instances. A graph neural network is then harnessed for neural message passing over adjacent features and instances in a latent space (abstraction). 3) The GNN inductively computes embeddings for new features based on those of existing ones, mimicking the reasoning process from familiar concepts to new ones in our brain. 4) The newly obtained embeddings, which capture both semantics and feature-level relations, are used to obtain hidden representations of new data with unseen features and make final decisions.
To endow the model with the ability to extrapolate to new features, we propose two training algorithms, one by self-supervised learning and one by inductive learning. The proposed model and its learning algorithms can easily scale to large datasets (with millions of features and instances) via mini-batch training. We also provide a theoretical analysis of the generalization error on test data with new features, which allows us to dissect the impact of training features and algorithms on generalization performance. To verify our approach, we conduct extensive experiments on several real-world datasets, including six classification datasets from biology, engineering and social domains with diverse features, as well as two large-scale datasets for advertisement click prediction. Our model is trained on training data using partial features and tested on test data with a mixture of seen and unseen features. The results demonstrate that 1) our approach consistently outperforms models that do not use new features for inference; 2) our approach achieves higher average Accuracy than baseline methods using KNN, pooling or local aggregation for feature extrapolation; 3) our approach even exceeds models using incremental training on new features, yielding higher Accuracy on average.
Our contributions are: 1) We formulate the open-world feature extrapolation problem and show that it is feasible to extend neural models for extrapolating to new features without re-training; 2) We propose a new graph-learning model and two training algorithms for the feature extrapolation problem; 3) Our theoretical analysis shows that the generalization error for data with new features relies on the number of training features and the randomness in training algorithms; 4) We conduct comprehensive experiments and show the effectiveness, applicability and scalability of the proposed method.
2 Methodology
We focus on attribute features as input in this paper. An input instance is a vector $\mathbf{x}_r = [x_1, x_2, \ldots, x_m]$ where each entry $x_j$ denotes a raw feature (like age, occ., edu., etc.). If $x_j$ is a discrete/categorical raw feature with $c_j$ possible values, its space is an integer set $\{1, 2, \ldots, c_j\}$. If $x_j$ is a continuous one, a common practice is to convert it into a discrete feature within space $\{1, 2, \ldots, c_j\}$ by even division ffm or log transformation feature-ctr . We call $c_j$ the cardinality of the $j$-th raw feature.
An effective way to handle attribute features is via one-hot encoding wd ; ffm ; deepfm . For $x_j$ with cardinality $c_j$, we convert it into a $c_j$-dimensional one-hot vector where the unique 1 indexes the value. In this way, one can convert an input $\mathbf{x}_r$ into a concatenation of one-hot vectors:
$\mathbf{x} = [\mathbf{o}_1, \mathbf{o}_2, \ldots, \mathbf{o}_m] \in \{0, 1\}^{d}, \qquad d = \sum_{j=1}^{m} c_j, \qquad (1)$
where $\mathbf{o}_j$ denotes the one-hot vector of the $j$-th raw feature.
We use $x_i$ to denote the $i$-th entry of $\mathbf{x}$ and call each $x_i$ a feature in this paper. We let $d$ denote the number of features and, as a reminder, $m$ is the number of raw features.
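To make the encoding concrete, the snippet below is a minimal sketch (our own illustration, with made-up cardinalities and a hypothetical helper name) of converting raw categorical values into the concatenated one-hot vector of Eq. (1).

```python
# A minimal sketch (illustrative only) of the multi-hot encoding in Eq. (1).
import numpy as np

def encode_instance(raw_values, cardinalities):
    """raw_values[j] in {0, ..., cardinalities[j]-1} for the j-th raw feature."""
    parts = []
    for v, c in zip(raw_values, cardinalities):
        one_hot = np.zeros(c, dtype=np.int8)
        one_hot[v] = 1
        parts.append(one_hot)
    return np.concatenate(parts)  # length d = sum_j c_j

# Example: three raw features (e.g., age bucket, occupation, education)
# with cardinalities 4, 3 and 2; the encoded vector has d = 9 entries.
x = encode_instance([2, 0, 1], [4, 3, 2])
print(x)  # [0 0 1 0 1 0 0 0 1]
```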
We next give a formal definition of the open-world feature extrapolation problem studied in this paper: given training data $\mathcal{D}_{tr} = \{(\mathbf{x}_i, y_i)\}_{i=1}^{N}$ where $\mathbf{x}_i \in \{0,1\}^{d_1}$, $y_i \in \mathcal{Y}$, and $\mathcal{F}_1 = \{1, \ldots, d_1\}$ is a set of feature indices, we aim to learn a model that can generalize to test data $\mathcal{D}_{te} = \{(\mathbf{x}'_i, y'_i)\}_{i=1}^{N'}$ where $\mathbf{x}'_i \in \{0,1\}^{d_2}$, $y'_i \in \mathcal{Y}$, and $\mathcal{F}_2 = \{1, \ldots, d_2\}$ is another set of feature indices. We term $\mathcal{F}_1$ the training feature space and $\mathcal{F}_2$ the test feature space. We assume 1) the label space $\mathcal{Y}$ is shared by training and test data, and 2) $\mathcal{F}_1 \subseteq \mathcal{F}_2$, i.e., the test feature space is an extension of the training feature space. The feature space expansion stems from two possible causes: 1) new raw features are incrementally collected from other fields (i.e., $m$ increases) or 2) new values appear out of the known space of existing raw features (i.e., some $c_j$ increases).
2.1 Proposed Model
Our model contains three parts: 1) a feature representation that builds a bipartite feature-data graph from input data; 2) a backbone network, which is essentially a neural network model that predicts the labels when fed with input data; 3) a GNN model that inductively computes features' embeddings based on their proximity and local structures to achieve feature extrapolation.
Feature Representation with Graphs. We stack the feature vectors of all the training data as a matrix $\mathbf{X} \in \{0,1\}^{N \times d_1}$ whose $i$-th row is $\mathbf{x}_i$. Then we treat each feature and instance as nodes and construct a bipartite graph between them. Formally, we define a node set $\mathcal{V} = \mathcal{V}_f \cup \mathcal{V}_d$, where node $u_j \in \mathcal{V}_f$ is associated with the $j$-th feature and node $v_i \in \mathcal{V}_d$ with the $i$-th instance in the training set. The binary matrix $\mathbf{X}$ constitutes an adjacency matrix whose non-zero entries indicate edges connecting nodes in $\mathcal{V}_f$ and $\mathcal{V}_d$, respectively. The induced feature-data bipartite graph will play an important role in our extrapolation approach. The representation is flexible w.r.t. variable-sized feature sets, enabling our model to handle test data with $d_2 > d_1$.
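As an illustration of this representation, the sketch below (toy data, assumed variable names) treats the 0-1 data matrix directly as the instance-to-feature adjacency of the bipartite graph.

```python
# A sketch of building the feature-data bipartite graph from the 0-1 data matrix.
import numpy as np
import scipy.sparse as sp

X = np.array([[1, 0, 1, 0],
              [0, 1, 1, 0],
              [1, 0, 0, 1]])      # 3 instance nodes, 4 feature nodes
A = sp.csr_matrix(X)              # sparse instance-to-feature adjacency (N x d)

# An edge connects instance node v_i and feature node u_j whenever X[i, j] = 1.
rows, cols = A.nonzero()
edges = sorted(zip(rows.tolist(), cols.tolist()))
print(edges)  # [(0, 0), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3)]
```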
Backbone Networks. We next consider a prediction model as a backbone network that maps data features $\mathbf{x}$ to a predicted label $\hat{y}$. Without loss of generality, a default choice for the backbone is a feedforward neural network. The first layer serves as an embedding layer which shrinks $\mathbf{x}$ into a $D$-dimensional hidden vector $\mathbf{h} = \mathbf{W}^\top \mathbf{x}$, where $\mathbf{W} \in \mathbb{R}^{d_1 \times D}$ denotes a weight matrix. The subsequent network (called the classifier) is often a stack of neural layers that predicts the label $\hat{y}$ from $\mathbf{h}$. We write the backbone as $f_\theta$ with $\theta = \{\mathbf{W}, \theta_c\}$ to highlight its two sets of parameters: the embedding matrix $\mathbf{W}$ and the classifier parameters $\theta_c$.
Notice that the matrix multiplication in the embedding layer is equivalent to a two-step procedure: 1) a lookup of feature embeddings and 2) a permutation-invariant aggregation. More specifically, we consider $\mathbf{W} = [\mathbf{w}_1, \ldots, \mathbf{w}_{d_1}]^\top$ as a stack of weight vectors where $\mathbf{w}_j \in \mathbb{R}^D$ corresponds to the embedding of feature $j$. The non-zero entries in $\mathbf{x}_i$ index the corresponding rows of $\mathbf{W}$ and induce a set of embeddings $\mathcal{E}_i = \{\mathbf{w}_j : x_{ij} = 1\}$, where each $\mathbf{w}_j$ is the embedding indexed by the one-hot vector of the corresponding raw feature. Then the hidden vector of the $i$-th instance can be obtained by aggregation, i.e., $\mathbf{h}_i = \sum_{\mathbf{w}_j \in \mathcal{E}_i} \mathbf{w}_j$, which is permutation-invariant w.r.t. the order of feature embeddings in $\mathcal{E}_i$. A more intuitive illustration is presented in Fig. 2.
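This equivalence between the embedding-layer matrix multiplication and the lookup-then-sum view can be checked numerically; below is a small sketch (made-up sizes, not tied to any dataset).

```python
# Numeric check: W^T x equals looking up the rows of W for the non-zero
# features and summing them (order-independent).
import numpy as np

rng = np.random.default_rng(0)
d, D = 6, 4
W = rng.normal(size=(d, D))        # one D-dimensional embedding per feature (rows of W)
x = np.array([1, 0, 1, 0, 0, 1])   # multi-hot input vector

h_matmul = x @ W                   # embedding layer as a matrix multiplication
h_lookup = W[x == 1].sum(axis=0)   # lookup the active rows, then sum-aggregate
assert np.allclose(h_matmul, h_lookup)
```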
The permutation-invariant property opens a way for handling variable-length feature vectors deepset . Essentially, provided that we have embeddings for the input features, we can add them up to get a fixed-dimensional hidden representation as input for the subsequent classifier. Therefore, the problem boils down to learning feature embeddings, especially how to extrapolate embeddings of new features based on those of existing ones.
Remark. Instead of using sum aggregation, some existing architectures consider concatenation of the embedding vectors $\mathbf{w}_j$'s, which is essentially equivalent to sum aggregation (see Appendix A for more details). Therefore, the permutation-invariant property holds for widely adopted deep models.

GNN for Feature Extrapolation. We proceed to propose a graph neural network (GNN) model for embedding learning with the feature-data graph. Our key insight is that the bipartite graph explicitly embodies features' co-occurrence in observed instances, which reflects the proximity among features. Once we conduct message passing for feature embeddings over the graph structure, the embeddings of similar features can be leveraged to compute and update each feature's embedding. The model can learn to extrapolate embeddings of new features using those of existing features with locality structures in a data-driven manner. The message passing over the defined graph representation is inductive w.r.t. variable-sized feature and instance node sets, which enables the model to tackle a new feature space with distinct feature sizes and supports.
Specifically, we consider the embedding $\mathbf{w}_j$ as the initial state $\mathbf{h}^{(0)}_{u_j}$ for feature node $u_j \in \mathcal{V}_f$. The initial states of instance nodes are set as zero vectors with the same dimension as the feature nodes, i.e., $\mathbf{h}^{(0)}_{v_i} = \mathbf{0}$. The interaction between the two sets of nodes $\mathcal{V}_f$ and $\mathcal{V}_d$ can be modeled via graph neural networks where the node states in the $l$-th layer are updated by
$\mathbf{h}^{(l)}_{v_i} = \mathbf{W}^{(l)}\, \mathrm{AGG}\big(\{\mathbf{h}^{(l-1)}_{u_j} : x_{ij} = 1\}\big), \qquad \mathbf{h}^{(l)}_{u_j} = \mathbf{W}^{(l)}\, \mathrm{AGG}\big(\{\mathbf{h}^{(l-1)}_{v_i} : x_{ij} = 1\}\big), \qquad (2)$
where $\mathbf{W}^{(l)}$ is a weight matrix, $\mathrm{AGG}(\cdot)$ denotes a neighborhood aggregator (e.g., that of GCN or GraphSAGE), and we do not use non-linearity since it empirically degrades the performance. For any new feature in test data, we can set its initial state as a zero vector. The GNN model outputs updated embeddings for feature nodes and we further use them as the feature embeddings in the backbone network. Fig. 3 presents the feedforward computation of the proposed model. Formally, with an $L$-layer GNN $g_\phi$, the GNN network gives updated feature embeddings $\hat{\mathbf{W}} = g_\phi(\mathbf{W}, \mathbf{X})$ and then the backbone network outputs the prediction $\hat{y}_i = f_{\theta_c}(\hat{\mathbf{W}}^\top \mathbf{x}_i)$, where $\mathbf{X}$ is the training data matrix for training and the test data matrix for test, and $\phi$ denotes the GNN's parameters.
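To make the computation flow concrete, below is a simplified PyTorch sketch (our own naming and simplifications, not the released FATE implementation) of the extrapolation step over a dense bipartite adjacency: unseen features start from zero states and receive messages from the instance nodes that contain them.

```python
# A simplified sketch of message passing over the feature-data bipartite graph.
import torch
import torch.nn as nn

class FeatureGNN(nn.Module):
    def __init__(self, dim, num_layers=2):
        super().__init__()
        # one linear map per layer, no non-linearity (as described in the text)
        self.layers = nn.ModuleList([nn.Linear(dim, dim, bias=False)
                                     for _ in range(num_layers)])

    def forward(self, feat_emb, X):
        # feat_emb: (d, D) initial feature states (zero rows for new features)
        # X:        (N, d) binary instance-feature adjacency as a float tensor
        deg_inst = X.sum(dim=1, keepdim=True).clamp(min=1)       # instance degrees
        deg_feat = X.sum(dim=0, keepdim=True).clamp(min=1).t()   # feature degrees
        h_feat = feat_emb
        for layer in self.layers:
            h_inst = (X @ h_feat) / deg_inst                 # features -> instances
            h_feat = layer((X.t() @ h_inst) / deg_feat)      # instances -> features
        return h_feat                                        # updated feature embeddings
```

At test time, new features simply correspond to extra zero-initialized rows of `feat_emb` and extra columns of `X`, so the same forward pass extrapolates their embeddings.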

2.2 Model Learning
We next discuss approaches for model training. In order to enable the model to extrapolate to new features, we put forward two useful strategies. 1) Proxy training data: we only use partial features from the training set as observed ones for each update. 2) Asynchronous updates: we decouple the training of the backbone network and the GNN network and use different updating speeds for them in a nested manner (see Fig. 4). Based on these strategies, we proceed to propose two specific training approaches.

Self-supervised Learning with N-fold Splits. To mimic new features in the future, we can mask some observed features and let the model use the remaining features to estimate the embeddings of the masked ones. For a given feature set of the training data, we consider an n-fold splitting method: in each iteration the features are first randomly shuffled and evenly divided into $n$ disjoint subsets, denoted by $\mathcal{F}^{(1)}, \ldots, \mathcal{F}^{(n)}$. We then consider an asynchronous updating rule for the two networks: each iteration contains $n$ updates for the backbone network and one update for the GNN model. For the $k$-th update of the backbone, we mask the features in $\mathcal{F}^{(k)}$ and set the initial states of the masked features as zero vectors before they are fed into the GNN. The GNN network uses the adjacency matrix to compute updated embeddings for the masked features. The embedding layer is then composed of the updated embeddings of masked features and the initial embeddings of the remaining features, based on which the backbone network outputs a prediction for each instance and computes the loss (e.g., cross-entropy for classification). Note that we still use supervised labels for the loss though we call this approach self-supervised learning. After the $n$ updates of the backbone network, we use the accumulated loss to update the GNN model. The training procedure repeats the above process until a given time budget.
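The following is a schematic sketch of one such iteration (assumed helper names such as `backbone`, `gnn` and the optimizers; this is our own simplification rather than the authors' exact code): the backbone is updated once per fold, while the GNN accumulates gradients and is updated once per iteration.

```python
# Schematic n-fold self-supervised iteration (simplified).
import torch

def nfold_iteration(backbone, gnn, X, y, feat_emb, n, opt_backbone, opt_gnn, loss_fn):
    # backbone: classifier part; feat_emb: (d, D) trainable embedding table
    d = X.shape[1]
    folds = torch.randperm(d).chunk(n)        # random even split of the feature set
    opt_gnn.zero_grad()
    for fold in folds:
        masked = feat_emb.clone()
        masked[fold] = 0.0                    # masked fold starts from zero states
        est = gnn(masked, X)                  # GNN extrapolates the masked embeddings
        emb = feat_emb.clone()
        emb[fold] = est[fold]                 # plug the estimated embeddings back in
        loss = loss_fn(backbone(X @ emb), y)  # e.g. cross-entropy on the labels
        opt_backbone.zero_grad()
        loss.backward()                       # grads reach the backbone and accumulate on the GNN
        opt_backbone.step()                   # one backbone update per fold
    opt_gnn.step()                            # one GNN update on the accumulated loss
```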
Inductive Learning with K-shot Samples. Alternatively, we can sample over the feature set and only expose partial features to the model for each update. For each update, we randomly sample $k$ raw features from the input data (one could also directly sample a certain ratio of features, which might lead to large variance), which induces a new feature set and extracts the corresponding columns of $\mathbf{X}$ to form a proxy data matrix (where each instance contains $k$ raw features). The proxy matrix is then fed into the GNN to obtain updated embeddings of the sampled features, based on which the backbone network outputs a prediction for each instance using only the sampled features.
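A corresponding sketch of the k-shot proxy-data construction is shown below (illustrative only, with an assumed mapping from raw features to the 0-1 columns they generate).

```python
# Build one proxy data matrix for the inductive k-shot strategy.
import numpy as np

def sample_proxy(X, raw_feature_columns, k, rng):
    """raw_feature_columns[j] lists the 0-1 column indices generated by raw feature j."""
    chosen = rng.choice(len(raw_feature_columns), size=k, replace=False)
    cols = np.concatenate([raw_feature_columns[j] for j in chosen])
    return X[:, cols], cols       # proxy sub-matrix and the exposed feature indices
```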
In comparison, the n-fold splitting contributes to better training stability since the model is updated on every feature in each iteration, while the inductive learning adds more randomness, which can help enhance the model's generalization. We further compare them in our experiments.
DropEdge for Regularization. In order to further alleviate over-fitting on training features, we use DropEdge dropedge to regularize our model. We consider a threshold $\delta$ and randomly set nonzero entries in $\mathbf{X}$ (for self-supervised learning) or the proxy data matrix (for inductive learning) to zero for each feedforward computation,
$\tilde{x}_{ij} = x_{ij} \cdot \mathbb{I}[\epsilon_{ij} > \delta], \qquad \epsilon_{ij} \sim \mathrm{Uniform}(0, 1). \qquad (3)$
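A minimal sketch of this regularizer (our reconstruction of the operation around Eq. (3), not the reference implementation) is given below.

```python
# DropEdge on the (proxy) adjacency: every observed edge, i.e. every non-zero
# entry, is independently zeroed with probability delta for this forward pass.
import torch

def drop_edge(X, delta):
    mask = (torch.rand_like(X.float()) > delta).float()
    return X.float() * mask
```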
Scaling to Large Systems. To handle prohibitively large datasets in practical systems, we can divide the data matrix into mini-batches along the instance dimension. Then, we feed each mini-batch into the model for one training update (including feature-level sampling/splitting, as shown in Fig. 4) or one inference pass. Since the number of nonzero features for each instance is no more than $m$ (a relatively small value), the edge number in each mini-batch is bounded by $bm$ (assuming a mini-batch contains $b$ instances). Hence, the space cost can be effectively controlled using instance-level mini-batch partition. Yet, note that $b$ cannot be arbitrarily small in order to guarantee sufficient message passing over diverse instances. We present the complete training algorithm in Appendix B, where the model is trained end-to-end using the self-supervised or inductive learning approach.
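Instance-level mini-batching can be sketched as follows (assumed shapes and names): each batch's rows define the subgraph that is fed to the model, so the per-batch edge count stays below the batch size times the number of raw features.

```python
# Instance-level mini-batch partition of the data matrix.
import numpy as np

def minibatches(X, y, batch_size, rng):
    order = rng.permutation(X.shape[0])
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        yield X[idx], y[idx]   # each batch carries its own slice of the bipartite graph
```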
3 Generalization Analysis
In this section, we analyze the generalization error on test data with new features. We simplify the settings for analysis: 1) the backbone network is a two-layer FNN (an embedding layer $\mathbf{W}$ plus a fully-connected layer $\boldsymbol{\theta}$) with sigmoid output; 2) the GNN network is an $L$-layer GCN which takes mean-pooling aggregation over neighboring nodes without linear transformation and non-linearity in each layer; 3) the training algorithm is SGD. With the above settings, the model can be written as $f(\mathbf{x}_i) = \sigma\big(\boldsymbol{\theta}^\top \sum_{\mathbf{x}_j \in \mathcal{N}_L(\mathbf{x}_i)} \alpha_{ij} \mathbf{W}^\top \mathbf{x}_j\big)$, where $\mathcal{N}_L(\mathbf{x}_i)$ contains all the $\mathbf{x}_j$'s that appear in the $L$-hop neighborhood of $\mathbf{x}_i$ in the feature-data graph and $\alpha_{ij}$ is a weight that quantifies the influence of $\mathbf{x}_j$ on $\mathbf{x}_i$ through the $L$-layer mean-pooling graph convolution. More details of the derivation are in Appendix C. Also, we focus our analysis on the case of inductive learning; the results can be extended to the self-supervised approach, which we leave for future work.
The data generation process can be described as follows. First, raw features are sampled from an unknown distribution $\mathcal{D}_F$ and form a feature set $F$. Then data $(\mathbf{x}, y)$ are sampled from a distribution $\mathcal{D}_{X \times Y \mid F}$ whose feature support is determined by $F$. The trained model can be denoted by $f(\mathbf{x})$ with loss function $\ell(f(\mathbf{x}), y)$.
Recall that in the training stage, we randomly partition the instances in $\mathbf{X}$ into mini-batches of size $b$ and each mini-batch further samples $k$ raw features to form a feature subset. In each update, the model is exposed to a sub-matrix of $\mathbf{X}$ as proxy training data and uses it for one feedforward and backward computation. With given training data $\mathbf{X}$, we define $S$ as the set of all the proxy data sub-matrices that could be exposed to the model during training.
The training process can be seen as a sequence of operations, each of which samples a sub-matrix from $S$ as proxy data in an i.i.d. manner and computes gradients for one SGD update (more discussions are in Appendix C). Define $A_S$ as a learning algorithm trained on $S$, which gives a trained model simplified as $f_{A_S}$. The generalization error can be defined as
$R(A_S) = \mathbb{E}_{F \sim \mathcal{D}_F}\, \mathbb{E}_{(\mathbf{x}, y) \sim \mathcal{D}_{X \times Y \mid F}} \big[ \ell(f_{A_S}(\mathbf{x}), y) \big], \qquad (4)$
where the expectation contains two stages of sampling: 1) a feature set $F$ is sampled according to $\mathcal{D}_F$, and 2) data $(\mathbf{x}, y)$ are sampled according to $\mathcal{D}_{X \times Y \mid F}$. The empirical risk that our approach optimizes with the training data would be
$\hat{R}(A_S) = \frac{1}{|S|} \sum_{\mathbf{X}_s \in S} \ell(f_{A_S}; \mathbf{X}_s), \qquad (5)$
where $\ell(f_{A_S}; \mathbf{X}_s)$ denotes the average loss over the instances in the proxy sub-matrix $\mathbf{X}_s$.
We study the expected generalization gap
$\mathbb{E}_{A}\big[ R(A_S) - \hat{R}(A_S) \big], \qquad (6)$
where the expectation is taken over the randomness of $A$ stemming from the sampling for SGD updates.
We assume that the loss function is Lipschitz-continuous and smooth w.r.t. the model output. Concretely, we have 1) $|\ell(a, y) - \ell(b, y)| \leq \rho\, |a - b|$ and 2) $|\partial_a \ell(a, y) - \partial_b \ell(b, y)| \leq \mu\, |a - b|$ for any model outputs $a, b$. Such conditions can be satisfied by widely used loss functions such as cross-entropy and MSE. Then we have the following result (see Appendix C for the proof).
Theorem 1.
Assume the loss function is bounded by $B$. For a learning algorithm $A_S$ trained on data $S$ with $T$ iterations of SGD updates, with probability at least $1 - \delta$, we have
(7)
The generalization gap depends on the number of raw features $m$ in the training data and the size of $S$. The latter is determined by the configuration of proxy training data, in particular, the sampling over training features. If the sampling introduces more randomness, $|S|$ becomes larger, contributing to a tighter gap. However, a large $|S|$ would also lead to large variance in training and amplify the optimization error. Therefore, there exists a trade-off w.r.t. how to sample/split observed features in the training stage. Furthermore, the generalization gap also depends on $m$, and a larger $m$ results in a looser bound. This is because a larger $m$ requires the model to deal with more features; as the network becomes wider and more complex, it is more prone to over-fitting.
4 Experiments
We apply our model FATE (FeATure Extrapolation Networks) to real-world datasets. First, we consider six classification datasets from the UCI Machine Learning Repository uci-dataset : Gene, Protein, Robot, Drive, Calls and Github, collected from domains like biology, engineering and social networks. The feature numbers range from 219 to 4006 and the instance numbers vary from 1080 to 58509. We also consider two large-scale datasets, Avazu and Criteo, from real-world online advertisement systems, where the goal is to predict the Click-Through Rate (CTR) of exposed advertisements to users. The two datasets have tens of millions of clicking/non-clicking records as instances and millions of 0-1 features. More dataset information and implementation details are in Appendix D and E, respectively.
4.1 Experiment on UCI Datasets

Setup. We randomly split all the instances into training/validation/test data with the ratio 6:2:2. Then we randomly select a certain ratio (from 0.3 to 0.8) of features as observed ones and use the remaining as unobserved ones. The model is trained with the observed features of training instances and tested with all the features of test instances. We adopt Accuracy as the metric for datasets with more than two classes (Gene, Protein, Robot, Drive and Calls) and ROC-AUC for Github with two classes.
Implementation. We specify FATE in the following ways. 1) Backbone: a 3-layer feedforward NN. 2) GNN: a 4-layer GCN. 3) Training: self-supervised learning with n-fold splits. Several baselines are considered for comparison and their architectures are all specified as a 3-layer feedforward NN. First, Base-NN, Average-NN, Pooling-NN and KNN-NN are all trained on training instances with observed features. Then Base-NN only uses test instances’ observed features for inference. Average-NN uses averaged embeddings of observed features as those of unobserved ones. Pooling-NN (resp. KNN-NN) computes embeddings for unobserved features via replacing our GNN with mean pooling aggregation over neighborhoods (resp. KNN aggregation over all the observed ones). Furthermore, we consider Oracle-NN using all the features of training data for training and INL-NN that is first trained on training data with observed features and then re-trained on training data with the remaining features.
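For clarity, the sketch below shows our reading of how the Average-NN and Pooling-NN baselines fill in embeddings for unobserved features (our own reconstruction from the description above, not the authors' code).

```python
# Simple extrapolation baselines for unobserved feature embeddings.
import numpy as np

def average_extrapolate(W_obs, num_new):
    # Average-NN: every new feature receives the mean embedding of observed features.
    return np.repeat(W_obs.mean(axis=0, keepdims=True), num_new, axis=0)

def pooling_extrapolate(W_obs, X_obs, X_new):
    # Pooling-NN: a new feature's embedding is the mean, over instances containing it,
    # of the mean embedding of that instance's observed features.
    inst_emb = (X_obs @ W_obs) / np.clip(X_obs.sum(1, keepdims=True), 1, None)
    deg_new = np.clip(X_new.sum(0, keepdims=True).T, 1, None)
    return (X_new.T @ inst_emb) / deg_new
```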
Results and Discussions. Fig. 5 reports the mean Accuracy/ROC-AUC of five trials with different ratios of observed features ranging from 0.3 to 0.8. FATE consistently achieves higher Accuracy and ROC-AUC than Base-NN, which uses only partial features for inference, and the improvements are statistically significant. The results show that FATE can learn effective embeddings for new features that contribute to better classification performance. Furthermore, FATE achieves higher Accuracy and ROC-AUC on average than the baselines Average-NN, KNN-NN and Pooling-NN. These baseline methods perform worse than Base-NN especially when observed features are few, which suggests that directly aggregating embeddings of observed features for extrapolation would degrade the performance. By contrast, FATE possesses superior capability for extrapolating to new unseen features. Even with a small ratio of observed features, the model is able to distill useful knowledge from unobserved features without re-training, providing decent classification performance. Notably, compared with INL-NN, FATE even achieves higher accuracy in most cases, with an Accuracy improvement on average. The possible reason is that INL-NN is prone to over-fitting on new data and forgetting the previous one. Finally, FATE achieves very close performance to Oracle-NN when using sufficient observed features and can even slightly exceed it on Gene, Robot and Github. In fact, the GNN network in FATE can not only achieve feature extrapolation, but also capture feature-level relations, which could be another merit of our method.
4.2 Experiment on CTR Prediction
| Dataset | Backbone | Model | T1 | T2 | T3 | T4 | T5 | T6 | T7 | T8 | Overall |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Avazu | NN | Base | 0.666 | 0.680 | 0.691 | 0.694 | 0.699 | 0.703 | 0.705 | 0.705 | 0.693 ± 0.012 |
| Avazu | NN | Pooling | 0.655 | 0.671 | 0.683 | 0.683 | 0.689 | 0.694 | 0.697 | 0.697 | 0.684 ± 0.011 |
| Avazu | NN | FATE | 0.689 | 0.699 | 0.708 | 0.710 | 0.715 | 0.720 | 0.721 | 0.721 | 0.710 ± 0.010 |
| Avazu | DeepFM | Base | 0.675 | 0.684 | 0.694 | 0.697 | 0.699 | 0.706 | 0.708 | 0.706 | 0.697 ± 0.009 |
| Avazu | DeepFM | Pooling | 0.666 | 0.676 | 0.685 | 0.685 | 0.688 | 0.693 | 0.694 | 0.694 | 0.685 ± 0.009 |
| Avazu | DeepFM | FATE | 0.692 | 0.702 | 0.711 | 0.714 | 0.718 | 0.722 | 0.724 | 0.724 | 0.713 ± 0.010 |
| Criteo | NN | Base | 0.761 | 0.761 | 0.763 | 0.763 | 0.765 | 0.766 | 0.766 | 0.766 | 0.764 ± 0.002 |
| Criteo | NN | Pooling | 0.761 | 0.762 | 0.764 | 0.763 | 0.766 | 0.767 | 0.768 | 0.768 | 0.765 ± 0.001 |
| Criteo | NN | FATE | 0.770 | 0.769 | 0.771 | 0.772 | 0.773 | 0.774 | 0.774 | 0.774 | 0.772 ± 0.001 |
| Criteo | DeepFM | Base | 0.772 | 0.771 | 0.772 | 0.772 | 0.774 | 0.774 | 0.774 | 0.774 | 0.773 ± 0.001 |
| Criteo | DeepFM | Pooling | 0.772 | 0.772 | 0.773 | 0.774 | 0.776 | 0.776 | 0.776 | 0.776 | 0.774 ± 0.002 |
| Criteo | DeepFM | FATE | 0.781 | 0.780 | 0.782 | 0.782 | 0.784 | 0.784 | 0.784 | 0.784 | 0.783 ± 0.001 |
Setup. We split the dataset in chronological order to simulate real-world cases. For the Avazu dataset, which contains time information over ten days, we use the data of the first/second day for training/validation and the data of the third to tenth days for test. For the Criteo dataset, whose records are given in temporal order, we split the dataset into ten continual subsets of equal size and use the first/second subset for training/validation and the third to tenth subsets for test. With such data splitting, we naturally obtain validation/test data with a mixture of features seen and unseen in the training data (the new features come from new values out of the known range of raw features). For Avazu/Criteo, there are 0.6/1.3 million features in training data, 0.2/0.4 million new features in validation data and a total of 1.1/0.8 million new features in all the test splits. We use ROC-AUC as the evaluation metric.
Implementation Details. We consider two specifications for our backbone: 1) a 3-layer feedforward NN, and 2) DeepFM deepfm , a widely used model for advertisement click prediction that considers inter-feature interactions over an NN. The GNN is a 2-layer GraphSAGE model. The training method is inductive learning with k-shot samples and mini-batch partition. Also, we compare with the baselines Base and Pooling. The KNN method would suffer from scalability issues on the two large datasets.
Results and Discussions. Table 1 reports the ROC-AUC results for different test splits, which show that FATE significantly outperforms Base and Pooling in all the test splits using NN and DeepFM as backbones. Overall, compared with Base, FATE improves the ROC-AUC by 0.017/0.016 (resp. 0.008/0.01) with NN/DeepFM as the backbone on Avazu (resp. Criteo). Note that even an improvement of 0.004 in ROC-AUC is considered significant in click prediction tasks wd ; deepfm . The results show that FATE can combine useful information in new features collected in the future for the target task without further training and has promising power for enhancing real-world systems interacting with the open world. Compared with Pooling, FATE achieves significant AUC improvements on the two datasets. The reason is that directly using average pooling in place of the GNN convolution limits model capacity and weakens the ability for effective concept abstraction and reasoning.
4.3 Further Discussions
Scalability. We study the model's scalability w.r.t. different batch sizes and numbers of features in Fig. 6 and 7, which show that our training and inference time/space scale linearly on the Criteo dataset.
Ablation Studies. Table 3 also provides ablation studies, which show that 1) the DropEdge operation helps regularize the training and brings higher test accuracy; 2) using asynchronous updates for the two networks leads to performance gains over joint training; 3) the n-fold splitting and k-shot sampling outperform each other in different cases and both exceed the leave-one-out partition. Table 4 compares different sampling sizes for the inductive learning approach on Avazu and Criteo, which verifies our theoretical analysis in Section 3. See more discussions in Appendix F.
Visualization. Fig. 6 visualizes the feature embeddings produced by FATE-NN and Oracle-NN. It unveils two interesting insights that can interpret why FATE can sometimes outperform Oracle, which uses all the features for training. First, FATE-NN's produced embeddings for observed and unobserved features possess more dissimilar distributions in latent space, compared to Oracle-NN. Notice that features' embeddings are used by the backbone network to compute intermediate hidden representations for each instance. Such a phenomenon implies that FATE manages to extract more informative knowledge from new features. Second, the embeddings of FATE-NN form particular structures (clusters, lines or curves) rather than distributing uniformly over the 2-D plane like Oracle-NN. The reason is that the GNN network leverages locality structures among features for further abstraction, which explicitly encodes feature-level relations and could help downstream classification.

5 Connection to Other Learning Paradigms
Our introduced problem setting, open-world feature extrapolation (OFE), can be treated as an instantiation of out-of-distribution generalization ood-classic-2 ; ood-classic-3 or the domain shift problem, focusing on distribution shift caused by feature space expansion. We next discuss the relationships of our problem setting and our model FATE with domain adaptation (DA), continual learning (CL), open-set learning (OSL) and zero-shot learning (ZSL). In general, OFE is orthogonal to these problems and opens a new direction that can potentially have promising intersections with the well-established ones.
Domain Adaptation adapts a model trained on a source domain to a target domain with a different distribution da-old1 ; da-old2 ; da-old3 ; da-nn1 ; da-nn2 ; da-nn3 ; da-nn4 . Our problem OFE is different from DA in two aspects: 1) the label space/distribution of training and test data is the same for OFE, while DA mostly considers different label distributions for source and target domains; 2) OFE focuses on combining new features that are related to the current task, while DA considers different tasks from different domains.
Continual Learning, or lifelong learning, aims at enabling a single model to learn from a stream of data from different tasks that cannot be seen at one time cl-old1 ; cl-1 ; cl-2 ; cl-survey . By contrast, OFE differs in two aspects. First, OFE does not allow finetuning or further retraining on new data, which can be more challenging than CL. Second, CL mostly assumes each piece of data in the stream comes from a different task with different labels to handle. The key challenge of CL is catastrophic forgetting cl-old1 ; cl-survey , which requires the model to balance a trade-off between previous and new tasks, while FATE is free from such an issue in nature since we do not require incremental learning.
Open-set Learning is another line of research related to ours. Differently, open-set recognition mostly focuses on the expansion of label sets openset-cv1 ; openset-cv2 ; openset-cv3 ; openset-cv4 ; openset-cv5 ; openset-nlp . To our knowledge, we are the first to study feature set expansion, formulate it as OFE and further solve it via graph learning.
Zero-shot Learning. Our problem setting is also linked with few-shot/zero-shot learning. In NLP domains, some studies focus on dealing with rare entities exposed a limited number of times or new entities unseen during training nlp-1 ; nlp-3 . Similar problems are also encountered and explored in cold-start recommender systems where there are new users/items unseen before recsys-1 ; recsys-2 ; recsys-3 . A commonality of these works is that they aim at inferring the embeddings for new entities based on some ‘held-out’ ones. With a similar end and a distinct methodological aspect, a recent study recsys-3 explores a new possibility via learning a latent graph between existing entities (users) and newly arrived ones through an attention mechanism for inductive representation learning. The core technical contributions of our work lie in the unique problem setting, which stays focused on feature space expansion (domain shift), and the proposed feature-level sampling/partition training strategy, backed up with our theoretical insights.
Extension and Outlooks. Our work can be extended to solve more problems and push development in broader areas. First, the input feature vectors can be replaced by feature maps given by a CNN or word/sentence embeddings given by a Transformer transformer , for handling multi-modal data in the context of federated learning mm-2 or multi-view learning mm-4 ; mm-6 ; mm-7 . Admittedly, our formulation assumes multi-hot feature vectors as input, which is a common practice for handling attribute features but not often the case for other data formats (like vision or text). For practitioners who would like to apply FATE to broader areas, one can extend our graph representation and GNN model to directly handle continuous features by treating feature values as edge weights, as is done by graphlearning-missing . Second, as shown in our experiments, the classification layers can be replaced by more complex models with inter-feature interactions ffm ; xdeepfm ; nfm ; fwl or advanced architectures danser ; featevolve for more sufficient feature-wise interaction and improved expressiveness.
6 Conclusion
We present a new framework to address features unseen in training by formulating the open-world feature extrapolation problem. We target the problem via graph representation learning: we treat observed data as a feature-data graph and harness a GNN to inductively compute embeddings for new features from those of existing ones, mimicking the abstraction and reasoning process in the human brain. We also propose two training strategies for effective feature extrapolation learning. Our theoretical results show that the generalization error depends on training features and learning algorithms. Experiments verify the effectiveness of our approach and its scalability to large-scale systems.
Potential Societal Impacts. When learning a mapping from features to labels, the model is at risk of focusing on dominant features from majority groups and ignoring scarce features from minority ones. A significant direction for extended work is to develop debiased methods for feature extrapolation. We believe AI models can be guided to promote social justice and well-being.
References
- [1] A. Asuncion and D. Newman. Uci machine learning repository. 2007.
- [2] B. B. Avants, N. J. Tustison, and J. R. Stone. Similarity-driven multi-view embeddings from high-dimensional biomedical data. Nature Computational Science, 1:143–152, 2021.
- [3] D. Baptista, P. G. Ferreira, and M. Rocha. Deep learning for drug response prediction in cancer. Briefings Bioinform, 22(1):360–379, 2021.
- [4] S. Ben-David, J. Blitzer, K. Crammer, A. Kulesza, F. Pereira, and J. W. Vaughan. A theory of learning from different domains. Machine learning, 79(1-2):151–175, 2010.
- [5] A. Bendale and T. E. Boult. Towards open world recognition. In Conference on Computer Vision and Pattern Recognition, pages 1893–1902, 2015.
- [6] A. Bendale and T. E. Boult. Towards open set deep networks. In Conference on Computer Vision and Pattern Recognition, pages 1563–1572, 2016.
- [7] G. Blanchard, G. Lee, and C. Scott. Generalizing from several related classification tasks to a new unlabeled sample. In Advances in Neural Information Processing Systems (NeurIPS), pages 2178–2186, 2011.
- [8] J. Blitzer, K. Crammer, A. Kulesza, F. Pereira, and J. Wortman. Learning bounds for domain adaptation. In Advances in Neural Information Processing Systems, pages 129–136, 2007.
- [9] O. Bousquet and A. Elisseeff. Stability and generalization. Journal of Machine Learning Research, 2:499–526, 2002.
- [10] C. Higuera, K. J. Gardiner, and K. J. Cios. Self-organizing feature maps identify proteins critical to learning in a mouse model of down syndrome. PLoS ONE, 10(6), 2015.
- [11] W. Chen, M. Chang, E. Schlinger, W. Y. Wang, and W. W. Cohen. Open question answering over tables and text. In International Conference on Learning Representations, 2021.
- [12] W. Chen, H. Zha, Z. Chen, W. Xiong, H. Wang, and W. Y. Wang. Hybridqa: A dataset of multi-hop question answering over tabular and textual data. In Conference on Empirical Methods in Natural Language Processing, pages 1026–1036, 2020.
- [13] H. Cheng, L. Koc, J. Harmsen, T. Shaked, T. Chandra, H. Aradhye, G. Anderson, G. Corrado, W. Chai, M. Ispir, R. Anil, Z. Haque, L. Hong, V. Jain, X. Liu, and H. Shah. Wide & deep learning for recommender systems. In Workshop on Deep Learning for Recommender Systems, pages 7–10, 2016.
- [14] P. Covington, J. Adams, and E. Sargin. Deep neural networks for youtube recommendations. In Conference on Recommender Systems, pages 191–198, 2016.
- [15] Y. Ganin and V. S. Lempitsky. Unsupervised domain adaptation by backpropagation. In International Conference on Machine Learning, pages 1180–1189, 2015.
- [16] Y. Ganin, E. Ustinova, H. Ajakan, P. Germain, H. Larochelle, F. Laviolette, M. Marchand, and V. S. Lempitsky. Domain-adversarial training of neural networks. The Journal of Machine Learning Research, 17:59:1–59:35, 2016.
- [17] H. Guo, R. Tang, Y. Ye, Z. Li, and X. He. Deepfm: A factorization-machine based neural network for CTR prediction. In International Joint Conference on Artificial Intelligence, pages 1725–1731, 2017.
- [18] W. L. Hamilton, Z. Ying, and J. Leskovec. Inductive representation learning on large graphs. In Advances in Neural Information Processing Systems, pages 1024–1034, 2017.
- [19] B. Hao, J. Zhang, H. Yin, C. Li, and H. Chen. Pre-training graph neural networks for cold-start users and items representation. In ACM International Conference on Web Search and Data Mining, pages 265–273, 2021.
- [20] M. Hardt, B. Recht, and Y. Singer. Train faster, generalize better: Stability of stochastic gradient descent. In International Conference on Machine Learning, pages 1225–1234, 2016.
- [21] M. Hassen and P. K. Chan. Learning a neural-network-based representation for open set recognition. In International Conference on Data Mining, pages 154–162, 2020.
- [22] X. He and T. Chua. Neural factorization machines for sparse predictive analytics. In International Conference on Research and Development in Information Retrieval, pages 355–364, 2017.
- [23] Y. Juan, Y. Zhuang, W. Chin, and C. Lin. Field-aware factorization machines for CTR prediction. In Conference on Recommender Systems, pages 43–50, 2016.
- [24] Y. Juan, Y. Zhuang, W. Chin, and C. Lin. Field-aware factorization machines for CTR prediction. In Conference on Recommender Systems, pages 43–50, 2016.
- [25] T. N. Kipf and M. Welling. Semi-supervised classification with graph convolutional networks. In International Conference on Learning Representations, 2017.
- [26] Y. Koren, R. M. Bell, and C. Volinsky. Matrix factorization techniques for recommender systems. Computer, 42(8):30–37, 2009.
- [27] Z. Li, J. Zhang, Y. Gong, Y. Yao, and Q. Wu. Field-wise learning for multi-field categorical data. In Advances in Neural Information Processing Systems, pages 6639–6649, 2020.
- [28] J. Lian, X. Zhou, F. Zhang, Z. Chen, X. Xie, and G. Sun. xdeepfm: Combining explicit and implicit feature interactions for recommender systems. In International Conference on Knowledge Discovery & Data Mining, pages 1754–1763, 2018.
- [29] P. P. Liang, T. Liu, Z. Liu, R. Salakhutdinov, and L. Morency. Think locally, act globally: Federated learning with local and global representations. CoRR, abs/2001.01523, 2020.
- [30] L. Logeswaran, M. Chang, K. Lee, K. Toutanova, J. Devlin, and H. Lee. Zero-shot entity linking by reading entity descriptions. In A. Korhonen, D. R. Traum, and L. Màrquez, editors, Conference of the Association for Computational Linguistics, pages 3449–3460, 2019.
- [31] M. Long, Z. Cao, J. Wang, and M. I. Jordan. Conditional adversarial domain adaptation. In Advances in Neural Information Processing Systems, pages 1647–1657, 2018.
- [32] Y. Mansour, M. Mohri, and A. Rostamizadeh. Domain adaptation: Learning bounds and algorithms. In Conference on Learning Theory, 2009.
- [33] K. Muandet, D. Balduzzi, and B. Schölkopf. Domain generalization via invariant feature representation. In International Conference on Machine Learning (ICML), pages 10–18, 2013.
- [34] P. Oza and V. M. Patel. C2AE: class conditioned auto-encoder for open-set recognition. In Conference on Computer Vision and Pattern Recognition, pages 2307–2316. Computer Vision Foundation / IEEE, 2019.
- [35] G. I. Parisi, R. Kemker, J. L. Part, C. Kanan, and S. Wermter. Continual lifelong learning with neural networks: A review. Neural Networks, 113:54–71, 2019.
- [36] G. I. Parisi, J. Tani, C. Weber, and S. Wermter. Lifelong learning of human actions with deep neural network self-organization. Neural Networks, 96:137–149, 2017.
- [37] S. Rendle. Factorization machines. In International Conference on Data Mining, pages 995–1000, 2010.
- [38] M. Richardson, E. Dominowska, and R. Ragno. Predicting clicks: estimating the click-through rate for new ads. In The Web Conference, pages 521–530, 2007.
- [39] A. V. Robins. Catastrophic forgetting in neural networks: the role of rehearsal mechanisms. In First New Zealand International Two-Stream Conference on Artificial Neural Networks and Expert Systems, ANNES ’93, Dunedin, New Zealand, November 24-26, 1993, pages 65–68. IEEE, 1993.
- [40] Y. Rong, W. Huang, T. Xu, and J. Huang. Dropedge: Towards deep graph convolutional networks on node classification. In International Conference on Learning Representations, 2020.
- [41] B. Rozemberczki, C. Allen, and R. Sarkar. Multi-scale attributed node embedding. Journal of Complex Networks, 9(2), 2021.
- [42] A. A. Rusu, N. C. Rabinowitz, G. Desjardins, H. Soyer, J. Kirkpatrick, K. Kavukcuoglu, R. Pascanu, and R. Hadsell. Progressive neural networks. CoRR, abs/1606.04671, 2016.
- [43] T. Schick and H. Schütze. Rare words: A major problem for contextualized embeddings and how to fix it by attentive mimicking. In AAAI Conference on Artificial Intelligence, pages 8766–8774, 2020.
- [44] L. Shu, H. Xu, and B. Liu. DOC: deep open classification of text documents. In Conference on Empirical Methods in Natural Language Processing, pages 2911–2916, 2017.
- [45] Y. H. Tsai, P. P. Liang, A. Zadeh, L. Morency, and R. Salakhutdinov. Learning factorized multimodal representations. In International Conference on Learning Representations. OpenReview.net, 2019.
- [46] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, L. Kaiser, and I. Polosukhin. Attention is all you need. In Advances in Neural Information Processing Systems, pages 5998–6008, 2017.
- [47] S. Verma and Z. Zhang. Stability and generalization of graph convolutional neural networks. In International Conference on Knowledge Discovery & Data Mining, pages 1539–1548, 2019.
- [48] Q. Wu, L. Jiang, X. Gao, X. Yang, and G. Chen. Feature evolution based multi-task learning for collaborative filtering with social trust. In International Joint Conference on Artificial Intelligence, pages 3877–3883, 2019.
- [49] Q. Wu, H. Zhang, X. Gao, P. He, P. Weng, H. Gao, and G. Chen. Dual graph attention networks for deep latent representation of multifaceted social effects in recommender systems. In The Web Conference, pages 2091–2102, 2019.
- [50] Q. Wu, H. Zhang, X. Gao, J. Yan, and H. Zha. Towards open-world recommendation: An inductive model-based collaborative filtering approach. In International Conference on Machine Learning, pages 11329–11339, 2021.
- [51] T. Wu, E. K. I. Chio, H. Cheng, Y. Du, S. Rendle, D. Kuzmin, R. Agarwal, L. Zhang, J. R. Anderson, S. Singh, T. Chandra, E. H. Chi, W. Li, A. Kumar, X. Ma, A. Soares, N. Jindal, and P. Cao. Zero-shot heterogeneous transfer learning from recommender systems to cold-start search retrieval. In ACM International Conference on Information and Knowledge Management, pages 2821–2828, 2020.
- [52] M. Xu, R. Jin, and Z. Zhou. Speedup matrix completion with side information: Application to multi-label learning. In Advances in Neural Information Processing Systems, pages 2301–2309, 2013.
- [53] R. Ying, R. He, K. Chen, P. Eksombatchai, W. L. Hamilton, and J. Leskovec. Graph convolutional neural networks for web-scale recommender systems. In International Conference on Knowledge Discovery & Data Mining, pages 974–983, 2018.
- [54] R. Yoshihashi, W. Shao, R. Kawakami, S. You, M. Iida, and T. Naemura. Classification-reconstruction learning for open-set recognition. In Conference on Computer Vision and Pattern Recognition, pages 4016–4025, 2019.
- [55] J. You, X. Ma, D. Y. Ding, M. J. Kochenderfer, and J. Leskovec. Handling missing data with graph representation learning. In Advances in Neural Information Processing Systems (NeurIPS), 2020.
- [56] M. Zaheer, S. Kottur, S. Ravanbakhsh, B. Póczos, R. Salakhutdinov, and A. J. Smola. Deep sets. In Advances in Neural Information Processing Systems, pages 3391–3401, 2017.
- [57] H. Zhao, S. Zhang, G. Wu, J. M. F. Moura, J. P. Costeira, and G. J. Gordon. Adversarial multiple source domain adaptation. In Advances in Neural Information Processing Systems, pages 8568–8579, 2018.
Appendix A More Discussions for Permutation-Invariant Property of Embedding Layers
In Section 2.1, we mentioned that the embedding layer in the backbone network can be equivalently seen as a combination of embedding lookup and a sum aggregation which is permutation-invariant w.r.t. the order of input features. We provide an illustration for this in Fig. 2.
Equivalence between Concatenation and Sum Aggregation. To support the remark in Section 2.1, we next illustrate the equivalence between concatenation of features' embeddings and sum aggregation/pooling over features' embeddings. Assume we have feature embeddings $\{\mathbf{e}_1, \ldots, \mathbf{e}_m\}$ for an instance. We concatenate all the embeddings as a vector $[\mathbf{e}_1 \| \cdots \| \mathbf{e}_m]$ and feed it into a neural layer with weight matrix $\mathbf{V}$ to obtain the output. Notice that the weight matrix $\mathbf{V}$ can be decomposed into sub-matrices $\mathbf{V} = [\mathbf{V}_1; \cdots; \mathbf{V}_m]$ where $\mathbf{V}_j \in \mathbb{R}^{D \times D'}$, so the output equals $\sum_{j=1}^{m} \mathbf{V}_j^\top \mathbf{e}_j$. If we consider sum aggregation/pooling over the transformed embeddings $\{\mathbf{V}_j^\top \mathbf{e}_j\}$, the subsequent neural layer would be a weight matrix of dimension $D' \times D'$; setting it as the identity easily yields the same output. Hence, concatenation plus a fully-connected layer is equivalent to sum pooling plus a fully-connected layer. This observation indicates that our reasoning in the main text can be applied to general neural network-based models for attribute features and enables them to handle input vectors with variable-length features.
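The argument can be verified numerically; the sketch below (made-up sizes) checks that a concatenation of per-feature embeddings followed by a dense layer equals a sum over per-feature transformed embeddings.

```python
# Numeric check of the concatenation / sum-aggregation equivalence.
import numpy as np

rng = np.random.default_rng(0)
m, D, Dout = 3, 4, 5
embs = rng.normal(size=(m, D))                  # m feature embeddings of one instance
V_blocks = rng.normal(size=(m, D, Dout))        # sub-matrices V_1, ..., V_m

V_concat = V_blocks.reshape(m * D, Dout)        # dense layer acting on the concatenation
out_concat = np.concatenate(embs) @ V_concat
out_sum = sum(embs[j] @ V_blocks[j] for j in range(m))   # sum over transformed embeddings
assert np.allclose(out_concat, out_sum)
```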
Appendix B Training Algorithms
We present the training algorithms for our model in Alg. 1 where the model is trained end-to-end via self-supervised learning or inductive learning approaches.
Appendix C Analysis of Generalization Error
We provide a complete discussion and proof for the analysis of the generalization error of our approach. Some notations are repeatedly defined for a self-contained presentation in this section. Recall that we focus our analysis on the case of inductive learning with the k-shot sampling approach. Also, we simplify the model as follows: 1) the backbone network is a two-layer FNN (an embedding layer $\mathbf{W}$ plus a fully-connected layer $\boldsymbol{\theta}$) with sigmoid output; 2) the GNN network is an $L$-layer GCN which takes mean-pooling aggregation over neighboring nodes without linear transformation and non-linearity in each layer; 3) the training algorithm is SGD.
Derivation for model function. With our settings in Section 3, we write the model as
$f(\mathbf{x}_i) = \sigma\Big(\boldsymbol{\theta}^\top \sum_{\mathbf{x}_j \in \mathcal{N}_L(\mathbf{x}_i)} \alpha_{ij}\, \mathbf{W}^\top \mathbf{x}_j\Big), \qquad (8)$
where $\mathcal{N}_L(\mathbf{x}_i)$ is a set which contains the $\mathbf{x}_j$'s that appear in the $L$-hop neighborhood of $\mathbf{x}_i$ in the feature-data graph and $\alpha_{ij}$ is a weight that quantifies the influence of $\mathbf{x}_j$ on $\mathbf{x}_i$ through the $L$-layer mean-pooling graph convolution. Here we provide the detailed derivation. In fact, the embedding layer in the backbone can be seen as a one-layer graph convolution using sum pooling without linear transformation and non-linearity, which can be denoted by $\mathbf{H}^{(0)} = \mathbf{X}\mathbf{W}$ (recall that $\mathbf{X}$ is treated as an adjacency matrix of the feature-data graph in our model). The GNN model, which is an $L$-layer GCN with mean pooling and without linear transformation or non-linearity in each layer, can be denoted by $\mathbf{H}^{(L)} = \mathbf{Q}\,\mathbf{H}^{(0)}$, where $\mathbf{Q}$ is the propagation matrix induced by the $L$ layers of mean-pooling convolution over the bipartite graph. Hence, we have $\mathbf{H}^{(L)} = \mathbf{Q}\mathbf{X}\mathbf{W}$. Let $\alpha_{ij}$ be the entry of $\mathbf{Q}$ that quantifies the influence between instances $i$ and $j$ through the $L$-layer GCN. Converting the global view of graph convolution into a local view of each node's ego-network, we can obtain the expression in (8).
With given training data $\mathbf{X}$, we define $S$ as the set of all the data sub-matrices that could be sampled and exposed to the model during training.
The SGD training can be seen as a sequence of operations, each of which picks a sub-matrix from $S$ in an i.i.d. manner as proxy training data and leverages it to compute the updating gradient. We further introduce $S^{\setminus i}$, which removes the $i$-th sub-matrix, and $S^{i}$, which replaces the $i$-th sub-matrix by another one. Specifically, we have
Justification of the i.i.d. Sampling. In fact, for our inductive learning approach in Section 2.2, the observed features for each proxy data matrix are randomly sampled. The feature-level sampling at one time can be seen as $k$ times i.i.d. sampling from all the raw features without replacement. Denote by $\mathcal{R}$ the set of all the raw features in the training set, by $I$ a set of $k$ distinct indices in $\{1, \ldots, m\}$, and by $\mathcal{R}_I$ a subset of raw features with indices from $I$. Obviously, there are $\binom{m}{k}$ different configurations for $I$ (or $\mathcal{R}_I$) in total. We can equivalently treat one feature-level sampling as a one-time i.i.d. sampling from a set of candidates which contains $\binom{m}{k}$ index sets, each containing $k$ indices from $\{1, \ldots, m\}$. Next we discuss two cases.
1) If we do not consider instance-level mini-batch partition, then the set $S$ will consist of $\binom{m}{k}$ sub-matrices. Specifically, the $i$-th sub-matrix is induced by the index set $I_i$, which extracts the columns of $\mathbf{X}$ corresponding to the features generated by the raw features in $\mathcal{R}_{I_i}$.
2) If we use instance-level mini-batch partition, the case would be a bit more complicated. First of all, the instance-level partition is not a strictly i.i.d. sampling process over training instances since there exists dependency among different mini-batches within one epoch. Yet, in practice, the batch size is very large in our experiments, so the number of mini-batches in one epoch is small, which allows us to neglect the dependency within one epoch. Furthermore, since the instance-level selection is independent of the feature-level sampling, the whole sampling process for proxy data can be seen as a series of i.i.d. sampling over sub-matrices of $\mathbf{X}$, which constitute the set $S$ in this case.
Next, we recall the generalization gap of our interest. The generalization error is defined as
$R(A_S) = \mathbb{E}_{F \sim \mathcal{D}_F}\, \mathbb{E}_{(\mathbf{x}, y) \sim \mathcal{D}_{X \times Y \mid F}} \big[ \ell(f_{A_S}(\mathbf{x}), y) \big], \qquad (9)$
where the expectation contains two stages of sampling: 1) a feature set $F$ is sampled according to $\mathcal{D}_F$, and 2) data $(\mathbf{x}, y)$ are sampled according to $\mathcal{D}_{X \times Y \mid F}$. The empirical risk that our approach optimizes with the training data would be
$\hat{R}(A_S) = \frac{1}{|S|} \sum_{\mathbf{X}_s \in S} \ell(f_{A_S}; \mathbf{X}_s), \qquad (10)$
where $\ell(f_{A_S}; \mathbf{X}_s)$ denotes the average loss over the instances in the proxy sub-matrix $\mathbf{X}_s$.
Then the expected generalization gap would be
$\mathbb{E}_{A}\big[ R(A_S) - \hat{R}(A_S) \big], \qquad (11)$
where the expectation is taken over the randomness of $A$ that stems from the sampling in SGD.
We next prove the result of Theorem 1 in the main text. Our proof is based on algorithmic stability analysis [9], following similar lines of reasoning as [20, 47]. The main idea of the stability analysis is to bound the output difference of a loss function under a single data point perturbation. Differently, in our case, the ‘data point’ is a data sub-matrix in $S$. Therefore, our proof can be seen as an extension of stability analysis to matrix data or graphs as input. The proof can be divided into two parts. First, we derive a generalization error bound on the condition of $\epsilon$-uniform stability of the learning algorithm. Then we prove the bound for $\epsilon$ based on our model architecture and SGD training.
C.1 Generalization error with uniform stability condition
We first introduce the definition of uniform stability of a randomized learning algorithm as a building block of our proof. A randomized learning algorithm $A$ is $\epsilon$-uniformly stable with regard to a loss function $\ell$ if it satisfies
$\sup_{i,\, (\mathbf{x}, y)} \Big| \mathbb{E}_A\big[\ell(f_{A_S}(\mathbf{x}), y)\big] - \mathbb{E}_A\big[\ell(f_{A_{S^{i}}}(\mathbf{x}), y)\big] \Big| \leq \epsilon. \qquad (12)$
We first prove a generalization bound using the uniform stability as a condition and then we prove that the learning algorithm in our case satisfies the condition.
Theorem 2.
Assume a randomized algorithm $A$ is $\epsilon$-uniformly stable with a loss function bounded by $B$. Then with probability at least $1 - \delta$ ($\delta \in (0, 1)$), over the random draw of $S$, we have
$R(A_S) \leq \hat{R}(A_S) + 2\epsilon + \big(4|S|\epsilon + B\big)\sqrt{\frac{\log(1/\delta)}{2|S|}}. \qquad (13)$
Proof.
Using triangle inequality, the stability property in (12) yields,
(14) |
We will use McDiarmid’s concentration inequality for the following proof. Let be a random variable set and . If it satisfies
(15) |
then we have
(16) |
Recall that data are assumed to be i.i.d. sampled, so we have (assuming )
(17) |
Using above equation and the -uniform stability we have
(18) |
Also we have the following inequalities,
(19) |
(20) |
Letting and using (19) and (20), we obtain
(21) |
Based on the above fact, we can apply the result of (16),
(22) |
Letting and using (18), we obtain the following result and conclude the proof.
(23) |
∎
C.2 Deriving the bound for $\epsilon$
We proceed to prove our main result in Theorem 1 by deriving the bound for based on the SGD algorithm and our GNN model. Let and denote the weight matrix of the classifier in the backbone network. Recall that our model is . Hence, we have
(24) |
We need to bound the two terms in (24). First, notice that for , it satisfies and and the graph convolution with mean pooling induce the fact that . Using the inequality of arithmetic and geometric means, we have .
We proceed to bound the second term by considering the randomness of SGD. We can define as model parameters and we need to derive bound for . Then define a sequence of model parameters where denotes the model parameters learned by SGD on with the updating in -th step as
(25) |
Similarly, denotes a sequence of model parameters learned by SGD on . We then derive bound for by considering two cases.
First, at step , SGD picks data and , i.e., exists in both and . This case will happen with probability . The derivative of model output is
(26) |
Using the fact , we have
(27) |
Second, at step , SGD picks and , i.e., picked by the algorithm on and picked by the algorithm on are distinct. This case would happen with probability . We have
(28) |
where the last inequality is due to .
Appendix D Dataset Information
| Dataset | Domain | #Instances | #Raw Feat. | Cardinality | #0-1 Feat. | #Class |
|---|---|---|---|---|---|---|
| Gene | Life | 3190 | 60 | 46 | 287 | 3 |
| Protein | Life | 1080 | 80 | 28 | 743 | 8 |
| Robot | Computer | 5456 | 24 | 9 | 237 | 4 |
| Drive | Computer | 58509 | 49 | 9 | 378 | 11 |
| Calls | Life | 7195 | 10 | 410 | 219 | 10 |
| Github | Social | 37700 | - | - | 4006 | 2 |
| Avazu | Ad. | 40,428,967 | 22 | 5–1611749 | 2,018,025 | 2 |
| Criteo | Ad. | 45,840,617 | 39 | 5–541311 | 2,647,481 | 2 |
D.1 Dataset Information
We present detailed information for our used datasets concerning the data collection, preprocessing and statistic information.
UCI datasets. The six datasets are provided by UCI Machine Learning repository [1]. They are from different domains, including biology, engineering and social networks. Gene dataset contains 60 DNA sequence elements, and the task is to recognize exon/intron boundaries of DNA. Protein dataset [10] consists of the expression levels of 77 proteins/protein modifications, genotype, treatment type and behavior, and the task is to identify subsets of proteins that are discriminant between eight classes of mice. Robot dataset is collected as a robot navigates through a room following a wall with 24 ultrasound sensor readings, and the task is to predict the robot behavior. Drive dataset is extracted from electric current drive signals with 49 attributes, and the task is to identify 11 different classes with different conditions. Calls dataset was created by segmenting audio records belonging to 4 different families, 8 genus, and 10 species, and the task is to identify the class of species. Github dataset [41] is a large social network of GitHub developers with their location, repositories starred, employer and e-mail address, and the task is to predict whether the GitHub user is a web or a machine learning developer.
The six UCI datasets have diverse statistics. Overall, they contain thousands of instances and dozens of raw features with a mix of categorical and continuous ones. The categorical raw features have cardinalities ranging from 2 to 12. As mentioned in Section 2.1, the cardinality means the number of possible values for a discrete feature. For continuous features in each dataset (if any), we first normalize the values to zero mean and unit standard deviation and then hash them into 10 buckets evenly partitioned between the minimum and the maximum. Then each raw feature can be converted into a one-hot representation. After converting all the features into binary ones, we get up to hundreds of 0-1 features for each dataset. Table 2 summarizes the basic information for each dataset.
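To make this preprocessing concrete, below is a minimal NumPy sketch (our own illustration, not the paper's code) of bucketizing one continuous raw feature into 10 evenly partitioned buckets and expanding it into 0-1 features; the helper names and edge-case handling are assumptions.

```python
import numpy as np

def bucketize_continuous(col, n_buckets=10):
    """Normalize a continuous column to zero mean / unit std, then hash it
    into buckets evenly partitioned between the (normalized) min and max."""
    col = (col - col.mean()) / (col.std() + 1e-8)            # 0-mean, 1-std
    edges = np.linspace(col.min(), col.max(), n_buckets + 1)[1:-1]
    return np.digitize(col, edges)                            # bucket id in [0, n_buckets)

def to_one_hot(bucket_ids, n_buckets=10):
    """Convert the bucket ids of one raw feature into 0-1 (one-hot) columns."""
    onehot = np.zeros((len(bucket_ids), n_buckets), dtype=np.float32)
    onehot[np.arange(len(bucket_ids)), bucket_ids] = 1.0
    return onehot

# Example: one continuous raw feature -> 10 binary (0-1) features
raw = np.random.randn(1000)
X_binary = to_one_hot(bucketize_continuous(raw))
```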
CTR prediction datasets. The two click-through rate (CTR) prediction datasets have millions of instances and dozens of raw features with diverse cardinality. The goal of the CTR prediction task is to estimate the probability that a user will click on an advertisement given the user's profile features and the ad's content features. Specifically, Criteo (http://labs.criteo.com/2014/02/kaggle-display-advertising-challenge-dataset/) is a widely used public benchmark dataset for developing CTR models, which includes 45 million users' click records with 13 continuous raw features and 26 categorical ones. (In the computational advertisement community, the raw features, e.g. site category, device id, device type, app domain, etc., are often called fields. We call them raw features in our paper to keep the notation self-contained.) We follow [24, 23] and use a log transformation to convert the continuous features into discrete ones. Avazu (https://www.kaggle.com/c/avazu-ctr-prediction) is another publicly accessible dataset for CTR prediction, which contains users' mobile behaviors, including whether a displayed mobile ad is clicked by a user or not. It has 40 million users' click records and 23 categorical raw features spanning from user/device features to ad attributes (all encoded to remove user identity information). The cardinalities of the raw features in these two datasets are very diverse, ranging from 5 to a million. The raw features with very large cardinality include some id features, e.g. device id, site id, app id, etc. For each dataset, we convert each raw feature into one-hot representations to obtain 0-1 features. Features appearing fewer than 4 times are grouped as one feature. After preprocessing, we obtain nearly 2 million 0-1 features for Avazu and Criteo, as shown in Table 2.
D.2 Dataset Splits
UCI datasets. For each of the UCI datasets, we first randomly partition all the instances into training/validation/test sets according to the ratio of 6:2:2. Then we randomly select a certain ratio ( in our experiments) of features as observed features and use the remaining ones as unobserved. The model is trained with observed features of training instances, validated with observed features of validation instances and tested with all the features of test instances.
CTR prediction datasets. As illustrated in Section 4, for Avazu/Criteo we split all the instances into ten folds in chronological order. Then we use the first fold for training, the second fold for validation, and the third to tenth folds for test. In this way, the validation data and test data naturally contain new features that do not appear in the training data. Here we provide more illustration of this. As mentioned above, the Avazu dataset contains 23 categorical raw features and some of them have very large cardinality. For example, the cardinalities of the raw features app id and device id are 5481 and 381763, respectively. In practical systems, new apps are introduced and new devices are observed by the system as time goes by; they appear as new values outside the known range of existing raw features, which constitute new 0-1 features unseen by the model (as introduced in the beginning of Section 2.1). Since we chronologically divide the dataset into training/validation/test sets, the validation and test sets both contain a mixture of features seen in the training set and new features unseen in training. Concretely, for Avazu, there are 618411 features in the training set, 248614 new features (unseen in training data) in the validation set, and 1151000 new features (unseen in both training and validation sets) in all the test sets. For Criteo, there are 1340248 features in the training set, 472023 new features (unseen in training data) in the validation set, and 835210 new features (unseen in both training and validation sets) in all the test sets.
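For illustration, here is a small sketch (ours, with hypothetical helpers `chronological_folds` and `count_new_features`) of how the chronological ten-fold split and the per-fold counting of unseen 0-1 features could be implemented:

```python
import numpy as np

def chronological_folds(n_instances, n_folds=10):
    """Split instance indices (already in chronological order) into equal folds."""
    return np.array_split(np.arange(n_instances), n_folds)

def count_new_features(feature_sets):
    """feature_sets: list of sets of 0-1 feature ids, one per fold (chronological).
    Returns, for each fold, the number of features unseen in all earlier folds."""
    seen, new_counts = set(), []
    for feats in feature_sets:
        new_counts.append(len(feats - seen))
        seen |= feats
    return new_counts

# folds[0] -> training, folds[1] -> validation, folds[2:] -> test
```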
Appendix E Implementation Details
We present implementation details for our experiments for reproducibility. We implement our model as well as all the baselines with Python 3.8, Pytorch 1.7 and Pytorch Geometric 1.6. The experiments are all run on an RTX 2080Ti, except for our scalability test in Section 4.3 where we use an RTX 8000.
E.1 Details for UCI experiments
Architectures. For experiments on UCI datasets, the network architecture for our backbone network is:
• A three-layer neural network with hidden size 8 in each layer.
• The activation function is ReLU.
• The output layer is a softmax function for multi-class classification or sigmoid for two-class classification.
The architecture for our GNN network is:
• A four-layer GCN [25] network with hidden size 8 in each layer.
• Adding self-loop and using normalization for graph convolution in each layer.
• No activation unit is used.
A minimal sketch of both networks is given below.
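The following PyTorch / PyTorch Geometric sketch illustrates the two architectures listed above; the class names, constructor arguments, and the way the modules would be wired together in FATE are our own assumptions rather than the exact implementation.

```python
import torch.nn as nn
from torch_geometric.nn import GCNConv

class Backbone(nn.Module):
    """Three-layer feedforward backbone with hidden size 8 and ReLU activations."""
    def __init__(self, in_dim, hidden=8, n_classes=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, n_classes),   # softmax/sigmoid applied at the output or in the loss
        )

    def forward(self, x):
        return self.net(x)

class FeatureGNN(nn.Module):
    """Four-layer GCN with hidden size 8, self-loops and symmetric normalization,
    and no nonlinearity between layers."""
    def __init__(self, dim=8, n_layers=4):
        super().__init__()
        self.convs = nn.ModuleList(
            [GCNConv(dim, dim, add_self_loops=True, normalize=True) for _ in range(n_layers)]
        )

    def forward(self, h, edge_index):
        for conv in self.convs:
            h = conv(h, edge_index)         # no activation unit
        return h
```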
Training Details. We adopt the self-supervised learning approach with n-fold splitting. Concretely, in each epoch, we feed the whole training data matrix into the model and randomly divide all the observed features into disjoint sets . Then a nested optimization is considered: 1) we update the backbone network with steps, where in the -th step we mask the observed features in ; 2) we then update the GNN network with one step using the accumulated loss of the steps. The training procedure repeats the above process until a given budget of 200 epochs is reached. Also, in each epoch, the validation loss is averaged over the -fold data, where for the -th fold the features in are masked and the model uses the remaining observed features for prediction. Finally, we report the test accuracy achieved by the epoch that gives the minimum logloss on the validation dataset.
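The schematic sketch below illustrates one epoch of this nested optimization. It is a simplified reading of the procedure (for instance, it re-evaluates the masked folds for the slow GNN step instead of literally reusing the fast-step losses), and `forward_with_mask` is a hypothetical helper that runs the GNN-plus-backbone pipeline with the given feature mask.

```python
import torch

def train_epoch(X, y, backbone, gnn, opt_backbone, opt_gnn, criterion, n_folds=5):
    """One epoch: n fast backbone updates on masked feature folds, then one slow
    GNN update on the loss accumulated over the same folds."""
    perm = torch.randperm(X.size(1))
    folds = perm.chunk(n_folds)                        # disjoint feature sets

    # 1) fast updates of the backbone, one per fold of masked features
    for fold in folds:
        mask = torch.ones(X.size(1), dtype=torch.bool)
        mask[fold] = False                             # mask features in this fold
        loss = criterion(forward_with_mask(backbone, gnn, X, mask), y)
        opt_backbone.zero_grad(); loss.backward(); opt_backbone.step()

    # 2) one slow update of the GNN on the loss accumulated over the folds
    opt_gnn.zero_grad()
    total = 0.0
    for fold in folds:
        mask = torch.ones(X.size(1), dtype=torch.bool)
        mask[fold] = False
        total = total + criterion(forward_with_mask(backbone, gnn, X, mask), y)
    total.backward(); opt_gnn.step()
```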
Hyperparameters. Other hyper-parameters are searched with grid search on the validation dataset. We use the same hyperparameter settings for all six datasets, which indicates that our model is dataset-agnostic to some extent. The settings and search spaces are as follows:
• The learning rates , are searched within . We set and .
• The ratio for DropEdge is searched within . We set . (See the sketch after this list.)
• The fold number for data partition is searched within . We set .
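For reference, DropEdge on the feature-data bipartite graph can be sketched as randomly discarding a fraction of edges at each training iteration; this is our own minimal illustration (the edge_index layout follows the PyTorch Geometric convention) rather than the exact implementation.

```python
import torch

def drop_edge(edge_index: torch.Tensor, drop_ratio: float) -> torch.Tensor:
    """Randomly drop a fraction of edges from the feature-data bipartite graph
    during training (DropEdge-style regularization)."""
    num_edges = edge_index.size(1)
    keep = torch.rand(num_edges) >= drop_ratio          # Bernoulli keep mask
    return edge_index[:, keep]

# Usage: edge_index has shape [2, num_edges]; apply only in training mode.
# edge_index = drop_edge(edge_index, drop_ratio=0.5)
```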
Baselines. All the baselines are implemented with a three-layer neural network, the same as the backbone network in our model. The baselines are all trained with a given budget of 200 epochs, and we report the test accuracy achieved by the epoch that gives the minimum logloss on the validation dataset. They differ in how they leverage observed and unobserved (new) features during training and inference. The detailed information for the baseline methods is as follows.
• Base-NN. Use observed features of training instances for model training, and observed features of validation/test instances for model validation/test.
• Oracle-NN. Use all the features of training instances for model training, and all the features of validation/test instances for model validation/test.
• INL-NN. The training process contains two stages. In the first stage, we train the model with initialized parameters using observed features of training instances for 200 epochs and save the model at the epoch that gives the minimum logloss on the validation dataset (with observed features). In the second stage, we load the model saved in the first stage, train it using unobserved features of training instances for 200 epochs, and report the test accuracy (using all the features) achieved by the epoch that gives the minimum logloss on the validation dataset (using all the features).
• Average-NN. Use observed features of training instances for model training. In the test stage, we average the embeddings of observed features to obtain the embeddings of unobserved features. The model then uses all the features of test instances for inference (with the trained embeddings of observed features and the estimated embeddings of unobserved ones).
• Pooling-NN. Use observed features of training instances for model training. In the test stage, we replace the GNN model in FATE with mean pooling over neighboring nodes. Specifically, the embeddings of unobserved features are obtained by non-parametric message passing using mean pooling over the feature-data bipartite graph.
• KNN-NN. Use observed features of training instances for model training. In the test stage, we compute the Jaccard similarity scores between every pair of observed and unobserved features. For each unobserved feature, its embedding is obtained by averaging the embeddings of the observed features with the top 20% Jaccard similarities to the target unobserved feature. (A sketch of this embedding estimation is given after this list.)
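A minimal sketch (ours; the function name and tensor layout are assumptions) of the KNN-NN embedding estimation based on Jaccard similarity between 0-1 feature columns:

```python
import torch

def jaccard_knn_embeddings(X_obs_cols, X_new_cols, obs_emb, top_ratio=0.2):
    """Estimate embeddings for unobserved 0-1 features by averaging the embeddings
    of the most Jaccard-similar observed features (KNN-NN baseline).
    X_obs_cols: [N, d_obs] binary matrix, X_new_cols: [N, d_new] binary matrix,
    obs_emb:    [d_obs, h] trained embeddings of observed features."""
    inter = X_new_cols.t().float() @ X_obs_cols.float()         # [d_new, d_obs] intersections
    cnt_new = X_new_cols.sum(0, keepdim=True).t().float()       # [d_new, 1]
    cnt_obs = X_obs_cols.sum(0, keepdim=True).float()           # [1, d_obs]
    jaccard = inter / (cnt_new + cnt_obs - inter + 1e-8)        # |A∩B| / |A∪B|
    k = max(1, int(top_ratio * X_obs_cols.size(1)))             # top 20% observed features
    idx = jaccard.topk(k, dim=1).indices                        # [d_new, k]
    return obs_emb[idx].mean(dim=1)                             # [d_new, h]

# Average-NN is the degenerate case: obs_emb.mean(dim=0) for every new feature.
```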
E.2 Details for Avazu/Criteo experiments
Architectures. For experiments on the Criteo and Avazu datasets, we consider two specifications for the backbone network. First, we specify it as a feedforward NN, whose architecture is:
• A three-layer neural network with hidden sizes 10-400-400-1.
• The activation function is a ReLU unit, except for the last layer which uses sigmoid.
• We use BatchNorm and Dropout with probability 0.5 in each layer.
Second, we specify it as the DeepFM network [17], which also contains an embedding layer and a subsequent classification layer. The embedding layer is an embedding lookup that maps each nonzero index in to an embedding, denoted as , where denotes the embedding for the -th raw feature of instance . The subsequent classification layer can be denoted by
$\hat{y} = \mathrm{sigmoid}\left(\mathrm{FNN}(\mathbf{e}) + \mathrm{FM}(\mathbf{e})\right)$ (32)
where $\hat{y}$ is the predicted click probability, FNN is a feedforward neural network and FM is a factorization machine, which can be written as
$\mathrm{FM}(\mathbf{e}) = \sum_{i<j} \langle \mathbf{e}_i, \mathbf{e}_j \rangle$ (33)
For our model FATE-DeepFM, we use the GNN model to compute the feature embeddings, based on which we use the input feature vector of an instance to obtain its per-raw-feature embeddings and then plug them into the subsequent classification layer.
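For concreteness, below is a schematic PyTorch sketch of such a DeepFM-style classification layer under the standard formulation; the class name, hidden sizes, and the way the per-raw-feature embeddings are assembled from the GNN outputs are simplified assumptions rather than the exact FATE-DeepFM implementation.

```python
import torch
import torch.nn as nn

class DeepFMHead(nn.Module):
    """Classification layer combining a feedforward net (FNN) with a second-order
    factorization-machine (FM) term over per-raw-feature embeddings."""
    def __init__(self, n_raw_feats, emb_dim=10, hidden=400):
        super().__init__()
        self.fnn = nn.Sequential(
            nn.Linear(n_raw_feats * emb_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def fm(self, e):
        # e: [batch, n_raw_feats, emb_dim]; pairwise-interaction identity:
        # sum_{i<j} <e_i, e_j> = 0.5 * ((sum_i e_i)^2 - sum_i e_i^2), summed over dims
        square_of_sum = e.sum(dim=1) ** 2
        sum_of_square = (e ** 2).sum(dim=1)
        return 0.5 * (square_of_sum - sum_of_square).sum(dim=1, keepdim=True)

    def forward(self, e):
        # e holds one embedding per raw feature; in FATE-DeepFM these would be
        # assembled from the GNN-produced feature embeddings and the input vector.
        logits = self.fnn(e.flatten(1)) + self.fm(e)
        return torch.sigmoid(logits)
```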
The architecture for our GNN network is:
• A two-layer GraphSAGE [18] network with hidden size 10 in each layer.
• No activation unit.
Training Details. We adopt the inductive learning approach with k-shot sampling for training our model. Furthermore, we use instance-level mini-batch partition to control the space cost. Concretely, in each epoch, we first randomly shuffle all the training instances and partition them into mini-batches of size . Then for each mini-batch, we consider a training iteration where the backbone network is updated with steps and the GNN model is updated with one step. For the -th backbone update step, we uniformly sample raw features from the existing ones, obtain a new feature set induced by the sampled raw features, and extract the corresponding columns of the data matrix to form proxy data, i.e., a sub-matrix of . We then use the proxy data matrix to update the backbone. After the -step updates of the backbone, we update the GNN model with the accumulated loss of the steps.
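The sampling of proxy data described above can be sketched as follows (our own illustration with assumed tensor layouts; `raw_feat_ids` is a hypothetical mapping from 0-1 feature columns to raw-feature ids):

```python
import torch

def sample_proxy_batch(X_batch, raw_feat_ids, k_shot):
    """Build proxy data for one backbone update: uniformly sample k_shot raw
    features and keep only the 0-1 feature columns (sub-matrix) they induce.
    X_batch:      [batch, d] 0-1 feature matrix for a mini-batch of instances
    raw_feat_ids: [d] long tensor mapping each 0-1 column to its raw-feature id"""
    raw_ids = raw_feat_ids.unique()
    picked = raw_ids[torch.randperm(len(raw_ids))[:k_shot]]              # k sampled raw features
    col_mask = (raw_feat_ids.unsqueeze(1) == picked.unsqueeze(0)).any(dim=1)  # induced 0-1 columns
    return X_batch[:, col_mask], col_mask
```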
All the models, including FATE and the baselines, are trained with a given budget of 100 epochs. Every 10 training iterations, we compute the validation loss and ROC-AUC on the validation dataset. Finally, we report the test ROC-AUC achieved by the epoch that gives the highest ROC-AUC on the validation dataset.
Hyperparameters. Other hyper-parameters are searched with grid search on the validation dataset. The settings and search spaces are as follows:
• The learning rates , are searched within . We set and for NN as the backbone. For DeepFM as the backbone, we set and on Criteo, and on Avazu.
• The ratio for DropEdge is searched within . We set .
• The batch size is searched within . We set .
• The sampling size for data partition is searched within for Avazu and for Criteo. We set for Avazu and for Criteo.
Appendix F More Experiment Results
We supplement more experiment results as further discussion of our method, including scalability tests and ablation studies.
F.1 Scalability Test
We conduct experiments for the scalability test on the Criteo dataset. The scalability experiment is deployed on an RTX 8000 GPU with 48GB memory (though our comparison experiments in Sections 4.1 and 4.2 require less than 12GB memory for each trial).
Impact of Batch Sizes. We report the running time per mini-batch for training and inference on the Criteo dataset in Fig. 6(a) and (b), where the batch size is varied from 1e5 to 1e6. The results are averaged over 20 mini-batches. As the batch size increases, the training time and inference time both increase linearly, which shows that our model has linear scalability w.r.t. the number of instances for each update and inference. Also, in Fig. 6(c) and (d), we present the GPU memory cost for training and inference on the Criteo dataset. The space cost of FATE also increases linearly with respect to the batch size. Indeed, as discussed in Section 3.2, the time and space complexity of FATE is using mini-batch training, where is a relatively small value (up to a hundred). Hence, the empirical results verify our analysis.
Impact of Feature Numbers. We also discuss the model's scalability with respect to the number of features, i.e., the dimension of the feature vectors for training data ( for test data). There are 39 raw features in total in the Criteo dataset and we only use of them for experiments, which induces features; we again compare the training/inference time per mini-batch and the GPU memory cost. The results are shown in Fig. 7(a)-(d). We can see that as the feature number increases, the time and space costs both grow linearly, which indicates that FATE has linear scalability w.r.t. the feature number . In fact, a larger number of features requires a larger model size (for feature embeddings) and induces a larger computational graph due to the increase of ; the increase of also requires more training/inference time according to our complexity analysis.
F.2 Ablation Studies
We next conduct ablation studies for some key components in our framework and discuss their impacts on our model. The results are shown in Table 3 and Table 4.
Effectiveness of DropEdge. In Table 3, we compare with not using DropEdge regularization in the training stage. The results show that FATE consistently achieves superior accuracy across the six datasets, which demonstrates the effectiveness of DropEdge regularization in alleviating over-fitting on training features.
Effectiveness of Asynchronous Updates. We also compare our asynchronous updates (alternating fast updates for the backbone network and slow updates for the GNN) with directly training the two networks jointly end-to-end. The results show that FATE with asynchronous updates outperforms the joint training approach by a large margin, which verifies the effectiveness of our proposed asynchronous updating rule. The reason is that asynchronous updates decouple the training of the two networks and further help both models learn useful information from observed data. Also, we observe that using slow updates for the GNN network with the accumulated loss of several data splits can stabilize its training and alleviate over-fitting.
Comparison between Training Approaches. We further investigate the n-fold splitting and k-shot sampling strategies used in our training approaches in Table 3. Recall that on the UCI datasets we adopt the self-supervised learning approach for training. Here we compare our n-fold splitting with leave-one-out, which leaves out part of the features as a fixed set for masking during training, and k-shot sampling, which randomly samples training features as observed ones and masks the remaining ones for each update. The results show that the n-fold splitting and k-shot sampling strategies both provide superior performance on the six datasets. Furthermore, when using different 's, the relative performance of the n-fold splitting and k-shot sampling approaches diverges across cases. Overall, we found that n-fold splitting with or works the best on average. In fact, n-fold splitting and k-shot sampling both serve to mimic new features by exposing only part of the observed features to the model during training. The difference is that n-fold splitting guarantees that in each iteration the model is updated on every feature in the training set, while k-shot sampling introduces more randomness. Unlike the UCI datasets, on the two large-scale datasets Criteo and Avazu, where we adopt the inductive learning approach for training, we found that k-shot sampling works consistently better than n-fold splitting. One possible reason is that k-shot sampling increases the diversity of the proxy data (containing partial features and partial instances) used for each training update and can presumably help the model overcome feature-level over-fitting on large datasets. Such results are consistent with our theoretical generalization error analysis in Section 3.
Impact of Sampling Sizes. We next study the impact of the sampling size on the model performance. We use different 's for inductive learning on Avazu and Criteo. The results are shown in Table 4. As increases, the training AUC goes up, which demonstrates that a larger sampling size can help optimization since it reduces the variance of sampling and enhances training stability. However, it is not always beneficial to increase . When it becomes large enough and close to the number of raw features , the model suffers from over-fitting. These results further demonstrate that a large sampling size can lead to feature-level over-fitting, which echoes our theoretical results in Section 3. Recall that Theorem 1 shows that the model's generalization gap depends on the randomness in sampling over training features. When is large, there is less randomness from the feature-level data partition, which degrades the model's generalization ability.
Table 3: Ablation results (test accuracy) on the six UCI datasets.
Models | Gene | Protein | Robot | Drive | Calls | Github
---|---|---|---|---|---|---
w/o DropEdge | 0.9226 | 0.9031 | 0.8062 | 0.5261 | 0.9760 | 0.8688
End-to-end Joint | 0.9257 | 0.8963 | 0.8454 | 0.1073 | 0.9762 | 0.7557
FATE (ours) | 0.9345 | 0.9178 | 0.8815 | 0.6440 | 0.9839 | 0.8743
Leave-one-out | 0.8564 | 0.6574 | 0.7641 | 0.4448 | 0.9334 | 0.8533
n-fold split () | 0.8884 | 0.8426 | 0.8888 | 0.5910 | 0.9851 | 0.8723
n-fold split () | 0.9345 | 0.9178 | 0.8815 | 0.6440 | 0.9839 | 0.8743
n-fold split () | 0.9298 | 0.8398 | 0.8359 | 0.5234 | 0.9514 | 0.8771
k-shot sample () | 0.9404 | 0.9046 | 0.8839 | 0.5559 | 0.9812 | 0.8712
k-shot sample () | 0.9379 | 0.9102 | 0.8802 | 0.6060 | 0.9819 | 0.8712
k-shot sample () | 0.9304 | 0.8778 | 0.8568 | 0.5408 | 0.9611 | 0.8722
Table 4: ROC-AUC with different sampling sizes on Avazu and Criteo (training, validation, and test folds T1-T8).
Dataset | Sampling size | Train | Val | T1 | T2 | T3 | T4 | T5 | T6 | T7 | T8
---|---|---|---|---|---|---|---|---|---|---|---
Avazu | 11 | 0.7815 | 0.7369 | 0.6853 | 0.6950 | 0.7058 | 0.7093 | 0.7137 | 0.7186 | 0.7183 | 0.7193
 | 14 | 0.7842 | 0.7399 | 0.6896 | 0.6989 | 0.7080 | 0.7091 | 0.7142 | 0.7190 | 0.7201 | 0.7210
 | 17 | 0.7902 | 0.7433 | 0.6894 | 0.6995 | 0.7082 | 0.7105 | 0.7156 | 0.7203 | 0.7215 | 0.7216
 | 20 | 0.7978 | 0.7420 | 0.6872 | 0.6978 | 0.7080 | 0.7091 | 0.7146 | 0.7201 | 0.7201 | 0.7202
Criteo | 16 | 0.7955 | 0.7725 | 0.7669 | 0.7666 | 0.7688 | 0.7695 | 0.7714 | 0.7721 | 0.7722 | 0.7722
 | 21 | 0.7988 | 0.7752 | 0.7699 | 0.7695 | 0.7714 | 0.7721 | 0.7736 | 0.7739 | 0.7741 | 0.7744
 | 24 | 0.8005 | 0.7758 | 0.7701 | 0.7694 | 0.7712 | 0.7727 | 0.7732 | 0.7745 | 0.7740 | 0.7743
 | 27 | 0.8025 | 0.7747 | 0.7698 | 0.7683 | 0.7711 | 0.7713 | 0.7727 | 0.7743 | 0.7734 | 0.7744
 | 30 | 0.8057 | 0.7750 | 0.7690 | 0.7678 | 0.7701 | 0.7708 | 0.7725 | 0.7735 | 0.7723 | 0.7739







