具有逻辑回归的分类任务的R,statmodels,sklearn的比较

具有逻辑回归的分类任务的R,statmodels,sklearn的比较,第1张

具有逻辑回归的分类任务的R,statmodels,sklearn的比较

我遇到了类似的问题,最终在/ r /
MachineLearning
上发布了有关该问题的信息。事实证明,差异可以归因于数据标准化。如果数据标准化,则scikit-
learn所使用的任何方法来查找模型的参数都将产生更好的结果。scikit-
learn有一些讨论预处理数据(包括标准化)的文档,可在此处找到。

结果
Number of 'default' values : 333Intercept: [-6.12556565]Coefficients: [[ 2.73145133  0.27750788]]Confusion matrix[[9629   38] [ 225  108]]Score          0.9737Precision      0.7397Recall         0.3243
# scikit-learn vs. R# http://stackoverflow.com/questions/28747019/comparison-of-r-statmodels-sklearn-for-a-classification-task-with-logistic-regimport pandas as pdimport sklearnfrom sklearn.linear_model import LogisticRegressionfrom sklearn.metrics import confusion_matrixfrom sklearn import preprocessing# Data is available here.Default = pd.read_csv('https://d1pqsl2386xqi9.cloudfront.net/notebooks/Default.csv', index_col = 0)Default['default'] = Default['default'].map({'No':0, 'Yes':1})Default['student'] = Default['student'].map({'No':0, 'Yes':1})I = Default['default'] == 0print("Number of 'default' values : {0}".format(Default[~I]['balance'].count()))feats = ['balance', 'income']Default[feats] = preprocessing.scale(Default[feats])# C = 1e6 ~ no regularization.classifier = LogisticRegression(C = 1e6, random_state = 42)classifier.fit(Default[feats], Default['default'])  #fit classifier on whole baseprint("Intercept: {0}".format(classifier.intercept_))print("Coefficients: {0}".format(classifier.coef_))y_true = Default['default']y_pred_cls = classifier.predict_proba(Default[feats])[:,1] > 0.5confusion = confusion_matrix(y_true, y_pred_cls)score = float((confusion[0, 0] + confusion[1, 1])) / float((confusion[0, 0] + confusion[1, 1] + confusion[0, 1] + confusion[1, 0]))precision = float((confusion[1, 1])) / float((confusion[1, 1] + confusion[0, 1]))recall = float((confusion[1, 1])) / float((confusion[1, 1] + confusion[1, 0]))print("nConfusion matrix")print(confusion)print('n{s:{c}<{n}}{num:2.4}'.format(s = 'Score', n = 15, c = '', num = score))print('{s:{c}<{n}}{num:2.4}'.format(s = 'Precision', n = 15, c = '', num = precision))print('{s:{c}<{n}}{num:2.4}'.format(s = 'Recall', n = 15, c = '', num = recall))


欢迎分享,转载请注明来源:内存溢出

原文地址: https://outofmemory.cn/zaji/5647551.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-16
下一篇 2022-12-16

发表评论

登录后才能评论

评论列表(0条)

保存