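DeepFM is widely used in recommendation and other sparse-data scenarios. I found a third-party library (deepctr_torch) and tried it on my own data. The dataset consists almost entirely of continuous features, though, so it probably does not show DeepFM's strengths.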
import pandas as pd
import torch
from sklearn.metrics import log_loss, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from deepctr_torch.inputs import SparseFeat, DenseFeat, get_feature_names
from deepctr_torch.models import DeepFM
1. Loading the data
data0 = pd.read_csv('alldata_typeB.csv')
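# .copy() so the column assignments below modify an independent frame rather than a slice of data0
data = data0[data0['tag1'].isin(['train','test'])].copy()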
dense_features = [f for f in data if f not in ['uid','lable','recall_date','month','tag1']]
# data[sparse_features] = data[sparse_features].fillna('-10086', )
data[dense_features] = data[dense_features].fillna(0)
target = ['lable']
2. Preprocessing and model training
# for feat in sparse_features:
#     lbe = LabelEncoder()
#     data[feat] = lbe.fit_transform(data[feat])
mms = MinMaxScaler(feature_range=(0, 1))
data[dense_features] = mms.fit_transform(data[dense_features])
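The scaler above is fitted on all rows before the train/test split, so the test rows contribute to the min/max statistics. A possible variation, not what this post ran, is to fit on the training split only and reuse those statistics on the test split. The sketch below uses a made-up random matrix in place of the real dense features; `random_state=0` and the array shape are illustrative assumptions.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

# Toy stand-in for the dense feature matrix (hypothetical values).
rng = np.random.default_rng(0)
X = rng.normal(loc=50.0, scale=10.0, size=(100, 3))

X_train, X_test = train_test_split(X, test_size=0.2, random_state=0)

scaler = MinMaxScaler(feature_range=(0, 1))
X_train_scaled = scaler.fit_transform(X_train)  # learn min/max from training rows only
X_test_scaled = scaler.transform(X_test)        # reuse those statistics on the test rows

print(X_train_scaled.min(axis=0), X_train_scaled.max(axis=0))
```

With this arrangement the scaled test values can fall slightly outside [0, 1], because the test rows may contain values beyond the training minimum or maximum.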
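# This step matters when categorical features are present: they are embedded, so the model must be told how many embedding vectors each feature group has, which is counted with pandas' nunique().
# Here every feature is continuous, so each column is declared as a one-dimensional DenseFeat and no vocabulary size is needed.
fixlen_feature_columns = [DenseFeat(feat, 1)
                          for feat in dense_features]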
dnn_feature_columns = fixlen_feature_columns
linear_feature_columns = fixlen_feature_columns
feature_names = get_feature_names(
    linear_feature_columns + dnn_feature_columns)
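For comparison, here is a minimal sketch of what the commented-out sparse-feature path would compute if the data did contain categorical columns: fill missing values with a placeholder token, label-encode each column, and count distinct ids with `nunique()`. The frame and the column names `city` and `device` are invented for illustration and are not part of the dataset used in this post.

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Hypothetical toy frame; "city" and "device" are made-up categorical columns.
df = pd.DataFrame({
    "city": ["bj", "sh", None, "bj", "gz"],
    "device": ["ios", "android", "ios", "ios", None],
})
sparse_features = ["city", "device"]

# Missing categories become their own token, as in the commented-out fillna above.
df[sparse_features] = df[sparse_features].fillna("-10086")

vocab_sizes = {}
for feat in sparse_features:
    lbe = LabelEncoder()
    df[feat] = lbe.fit_transform(df[feat])  # map each category to an integer id in [0, n)
    vocab_sizes[feat] = df[feat].nunique()  # number of embedding rows this feature needs

print(vocab_sizes)
```

In deepctr_torch these counts would then be passed along the lines of `SparseFeat(feat, vocabulary_size=vocab_sizes[feat], embedding_dim=4)`, where the embedding size of 4 is only an example value. As far as I can tell, the FM interaction term in this library works on the embeddings of sparse features, which is one reason a dense-only dataset may not show the model at its best.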
# Finally, assemble the model inputs from the feature columns generated in the previous step
train, test = train_test_split(data, test_size=0.2)
train_model_input = {name: train[name] for name in feature_names}
test_model_input = {name: test[name] for name in feature_names}
# Check whether a GPU is available
device = 'cpu'
use_cuda = True
if use_cuda and torch.cuda.is_available():
    print('cuda ready...')
    device = 'cuda:0'
# Initialize the model, then train and predict
model = DeepFM(linear_feature_columns=linear_feature_columns, dnn_feature_columns=dnn_feature_columns, task='binary',
               l2_reg_embedding=1e-5, device=device)
model.compile("adagrad", "binary_crossentropy",
              metrics=["binary_crossentropy", "auc"])
model.fit(train_model_input, train[target].values,
          batch_size=256, epochs=30, validation_split=0.2, verbose=2)
pred_ans = model.predict(test_model_input, 256)
print("")
print("test LogLoss", round(log_loss(test[target].values, pred_ans), 4))
print("test AUC", round(roc_auc_score(test[target].values, pred_ans), 4))
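To make the two printed metrics easier to interpret, the sketch below recomputes them by hand on a four-sample toy example in plain Python. The helper names `log_loss_binary` and `auc_pairwise` are my own and only illustrate the definitions; the script above uses the scikit-learn functions.

```python
import math

def log_loss_binary(y_true, y_prob, eps=1e-15):
    """Mean negative log-likelihood of binary labels under predicted probabilities."""
    total = 0.0
    for y, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1 - eps)  # clip to avoid log(0)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(y_true)

def auc_pairwise(y_true, y_prob):
    """Fraction of (positive, negative) pairs ranked correctly; ties count as 0.5."""
    pos = [p for y, p in zip(y_true, y_prob) if y == 1]
    neg = [p for y, p in zip(y_true, y_prob) if y == 0]
    wins = 0.0
    for pp in pos:
        for pn in neg:
            if pp > pn:
                wins += 1.0
            elif pp == pn:
                wins += 0.5
    return wins / (len(pos) * len(neg))

y_true = [0, 0, 1, 1]
y_prob = [0.1, 0.4, 0.35, 0.8]
print(round(log_loss_binary(y_true, y_prob), 4))  # 0.4723
print(auc_pairwise(y_true, y_prob))               # 0.75

# Baseline: predicting 0.5 for everything on balanced labels gives ln(2).
print(round(log_loss_binary([0, 1], [0.5, 0.5]), 4))  # 0.6931
```

The last line is a useful reference point: on a balanced dataset, a model that always outputs 0.5 scores a log loss of about 0.6931 and an AUC of 0.5.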
Training results
cpu
Train on 78035 samples, validate on 19509 samples, 305 steps per epoch
Epoch 1/30
11s - loss: 0.6476 - binary_crossentropy: 0.6476 - auc: 0.6475 - val_binary_crossentropy: 0.6305 - val_auc: 0.6717
Epoch 2/30
12s - loss: 0.6269 - binary_crossentropy: 0.6269 - auc: 0.6807 - val_binary_crossentropy: 0.6236 - val_auc: 0.6815
Epoch 3/30
11s - loss: 0.6214 - binary_crossentropy: 0.6214 - auc: 0.6885 - val_binary_crossentropy: 0.6212 - val_auc: 0.6875
Epoch 4/30
12s - loss: 0.6182 - binary_crossentropy: 0.6182 - auc: 0.6929 - val_binary_crossentropy: 0.6207 - val_auc: 0.6894
Epoch 5/30
12s - loss: 0.6157 - binary_crossentropy: 0.6157 - auc: 0.6958 - val_binary_crossentropy: 0.6180 - val_auc: 0.6909
Epoch 6/30
13s - loss: 0.6138 - binary_crossentropy: 0.6138 - auc: 0.6982 - val_binary_crossentropy: 0.6187 - val_auc: 0.6911
Epoch 7/30
13s - loss: 0.6118 - binary_crossentropy: 0.6118 - auc: 0.6999 - val_binary_crossentropy: 0.6160 - val_auc: 0.6934
Epoch 8/30
13s - loss: 0.6106 - binary_crossentropy: 0.6106 - auc: 0.7020 - val_binary_crossentropy: 0.6154 - val_auc: 0.6943
Epoch 9/30
13s - loss: 0.6094 - binary_crossentropy: 0.6094 - auc: 0.7034 - val_binary_crossentropy: 0.6174 - val_auc: 0.6934
Epoch 10/30
13s - loss: 0.6082 - binary_crossentropy: 0.6082 - auc: 0.7050 - val_binary_crossentropy: 0.6146 - val_auc: 0.6961
Epoch 11/30
13s - loss: 0.6071 - binary_crossentropy: 0.6071 - auc: 0.7060 - val_binary_crossentropy: 0.6142 - val_auc: 0.6951
Epoch 12/30
13s - loss: 0.6063 - binary_crossentropy: 0.6063 - auc: 0.7068 - val_binary_crossentropy: 0.6170 - val_auc: 0.6959
Epoch 13/30
13s - loss: 0.6052 - binary_crossentropy: 0.6052 - auc: 0.7085 - val_binary_crossentropy: 0.6150 - val_auc: 0.6958
Epoch 14/30
13s - loss: 0.6045 - binary_crossentropy: 0.6045 - auc: 0.7095 - val_binary_crossentropy: 0.6127 - val_auc: 0.6974
Epoch 15/30
13s - loss: 0.6036 - binary_crossentropy: 0.6036 - auc: 0.7105 - val_binary_crossentropy: 0.6137 - val_auc: 0.6981
Epoch 16/30
13s - loss: 0.6029 - binary_crossentropy: 0.6029 - auc: 0.7112 - val_binary_crossentropy: 0.6129 - val_auc: 0.6984
Epoch 17/30
13s - loss: 0.6023 - binary_crossentropy: 0.6024 - auc: 0.7121 - val_binary_crossentropy: 0.6134 - val_auc: 0.6978
Epoch 18/30
13s - loss: 0.6013 - binary_crossentropy: 0.6014 - auc: 0.7130 - val_binary_crossentropy: 0.6146 - val_auc: 0.6973
Epoch 19/30
13s - loss: 0.6007 - binary_crossentropy: 0.6007 - auc: 0.7143 - val_binary_crossentropy: 0.6129 - val_auc: 0.6987
Epoch 20/30
13s - loss: 0.6001 - binary_crossentropy: 0.6001 - auc: 0.7144 - val_binary_crossentropy: 0.6116 - val_auc: 0.6995
Epoch 21/30
13s - loss: 0.5994 - binary_crossentropy: 0.5994 - auc: 0.7161 - val_binary_crossentropy: 0.6122 - val_auc: 0.6987
Epoch 22/30
13s - loss: 0.5987 - binary_crossentropy: 0.5986 - auc: 0.7170 - val_binary_crossentropy: 0.6122 - val_auc: 0.6991
Epoch 23/30
13s - loss: 0.5979 - binary_crossentropy: 0.5979 - auc: 0.7174 - val_binary_crossentropy: 0.6118 - val_auc: 0.6987
Epoch 24/30
13s - loss: 0.5975 - binary_crossentropy: 0.5974 - auc: 0.7186 - val_binary_crossentropy: 0.6120 - val_auc: 0.6995
Epoch 25/30
13s - loss: 0.5965 - binary_crossentropy: 0.5965 - auc: 0.7191 - val_binary_crossentropy: 0.6111 - val_auc: 0.6993
Epoch 26/30
13s - loss: 0.5958 - binary_crossentropy: 0.5958 - auc: 0.7203 - val_binary_crossentropy: 0.6113 - val_auc: 0.6994
Epoch 27/30
13s - loss: 0.5952 - binary_crossentropy: 0.5952 - auc: 0.7208 - val_binary_crossentropy: 0.6158 - val_auc: 0.6989
Epoch 28/30
13s - loss: 0.5946 - binary_crossentropy: 0.5946 - auc: 0.7217 - val_binary_crossentropy: 0.6115 - val_auc: 0.7006
Epoch 29/30
13s - loss: 0.5942 - binary_crossentropy: 0.5942 - auc: 0.7226 - val_binary_crossentropy: 0.6108 - val_auc: 0.7014
Epoch 30/30
13s - loss: 0.5933 - binary_crossentropy: 0.5933 - auc: 0.7232 - val_binary_crossentropy: 0.6143 - val_auc: 0.6991
test LogLoss 0.6192
test AUC 0.6942
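In the log above the training AUC keeps rising while the validation AUC moves very little after roughly epoch 20. As a rough way to read the log, the sketch below takes the `val_auc` values copied from it, finds the best epoch, and simulates a simple patience-based early stop. The `patience` value of 3 is an arbitrary assumption, and this is plain Python for illustration, not something the library ran.

```python
# val_auc per epoch, copied from the training log above (epochs 1..30).
val_auc = [
    0.6717, 0.6815, 0.6875, 0.6894, 0.6909, 0.6911, 0.6934, 0.6943, 0.6934, 0.6961,
    0.6951, 0.6959, 0.6958, 0.6974, 0.6981, 0.6984, 0.6978, 0.6973, 0.6987, 0.6995,
    0.6987, 0.6991, 0.6987, 0.6995, 0.6993, 0.6994, 0.6989, 0.7006, 0.7014, 0.6991,
]

def best_epoch(scores):
    """Return (1-based epoch, score) of the highest validation score."""
    idx = max(range(len(scores)), key=lambda i: scores[i])
    return idx + 1, scores[idx]

def simulate_early_stopping(scores, patience):
    """Return (stop_epoch, best_epoch) when training halts after `patience`
    consecutive epochs without a strictly better score."""
    best, best_ep, waited = float("-inf"), 0, 0
    for epoch, score in enumerate(scores, start=1):
        if score > best:
            best, best_ep, waited = score, epoch, 0
        else:
            waited += 1
            if waited >= patience:
                return epoch, best_ep
    return len(scores), best_ep

print(best_epoch(val_auc))                           # (29, 0.7014)
print(simulate_early_stopping(val_auc, patience=3))  # (13, 10)
```

The best validation AUC (0.7014 at epoch 29) is only about 0.005 above the value at epoch 10, so most of the later epochs add training time without much validation gain.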
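The classes in this dataset are well balanced, so it is fairly easy to train on, but the final AUC is still lower than that of a tree model (LightGBM). I will keep testing on other datasets.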
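Copyright notice: this is an original article by jin_tmac, released under the CC 4.0 BY-SA license. When reposting, please include a link to the original source and this notice.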