一、题目:
二、代码
import csv
import numpy as np
import pandas as pd
'''
1、数据处理(sex\one-hot encoding)[训练集/测试集 共同的数据处理部分]
2、训练集的使用
(1)计算两个类别各自的均值向量,以及协方差,训练集中两个类别各自的样本数
(2)计算w和b(在w^T=(u1-u2)^T的时候,数据使用的对应,就确定了谁是类1,谁是类2)
3、测试集的使用
(1)对于每一个测试样本x(即:每一行数据),根据公式、w、b,计算得到z
(2)计算sigmoid(-z),若大于0.5,说明就是类2;若小于0.5,说明就是类1,进而实现预测
其他说明:
为了使数据处理更清晰,将x和y的分开来
'''
def data_process_x(raw_data):
# remove sex and y
if "income" in raw_data:
data = raw_data.drop(['income', 'sex'], axis=1)
else:
data = raw_data.drop(['sex'], axis=1)
# insert sex
data.insert(0, 'sex', (raw_data['sex'] == ' Female').astype(np.int))
# one-hot encoding
object_data_column_list = [column for column in data.columns if data[column].dtype == 'object']
non_object_data_column_list = [column for column in data.columns if column not in object_data_column_list]
# flatten data(column)
object_data = data[object_data_column_list]
non_object_data = data[non_object_data_column_list]
# pd.get_dummies
# 第一行的表头 它会理解嘛?
object_data = pd.get_dummies(object_data)
# concatenate
data = pd.concat([object_data, non_object_data]
版权声明:本文为qq943686211原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。