首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >分层褶皱的嵌套交叉验证

分层褶皱的嵌套交叉验证
EN

Stack Overflow用户
提问于 2020-11-04 10:46:12
回答 1查看 306关注 0票数 1

我正在尝试使用scikit学习管道和嵌套交叉验证来实现一个随机森林回归器。数据集是关于房价的,有几个特性(有些是数字的,其他分类的)和一个连续的目标变量(median_house_value)。

代码语言:javascript
复制
Data columns (total 10 columns):
 #   Column              Non-Null Count  Dtype  
---  ------              --------------  -----  
 0   longitude           20640 non-null  float64
 1   latitude            20640 non-null  float64
 2   housing_median_age  20640 non-null  float64
 3   total_rooms         20640 non-null  float64
 4   total_bedrooms      20433 non-null  float64
 5   population          20640 non-null  float64
 6   households          20640 non-null  float64
 7   median_income       20640 non-null  float64
 8   median_house_value  20640 non-null  float64
 9   ocean_proximity     20640 non-null  object 

我决定手动创建两个分层的5折叠分裂(嵌套cv的内部、外部循环)。分层是基于median_income特性的一个修改版本:

代码语言:javascript
复制
df.insert(9, "income_cat", 
                  pd.cut(df["median_income"],bins=[0., 1.5, 3.0, 4.5, 6., np.inf], labels=[1,2,3,4,5]))

这是分裂褶皱的代码

代码语言:javascript
复制
cv1_5 = StratifiedShuffleSplit(n_splits = 5, test_size = .2, random_state = 42)
cv1_splits = []

# create first 5 stratified folds indices
for train_index, test_index in cv1_5.split(df, df["income_cat"]):
    cv1_splits.append((train_index, test_index))

cv2_5 = StratifiedShuffleSplit(n_splits = 5, test_size = .2, random_state = 43)
cv2_splits = []

# create second 5 stratified folds indices
for train_index, test_index in cv2_5.split(df, df["income_cat"]):
    cv2_splits.append((train_index, test_index))
    
# set initial dataset
X = df.drop("median_house_value", axis=1)
y = df["median_house_value"].copy()

这是预处理管道

代码语言:javascript
复制
# create preprocess pipe
preprocess_pipe = Pipeline(
    [
        ("ctransformer", ColumnTransformer([
                ( 
                    "num_pipe", 
                    Pipeline([
                        ("imputer", SimpleImputer(strategy="median")),
                        ("scaler", StandardScaler())
                    ]), 
                    list(X.select_dtypes(include=[np.number]))
                ),
                ( 
                    "cat_pipe", 
                    Pipeline([
                        ("encoder", OneHotEncoder()),
                    ]), 
                    ["ocean_proximity"])
            ])
        ),
    ]
)

这是最后的管道(包括预处理管道)。

代码语言:javascript
复制
pipe = Pipeline([
    ("preprocess", preprocess_pipe),
    ("model", RandomForestRegressor())
])

我使用嵌套交叉验证来调优最终管道的超参数并计算泛化误差。

这是参数网格

代码语言:javascript
复制
param_grid = [
    {
        "preprocess__ctransformer__num_pipe__imputer__strategy": ["mean","median"],
        "model__n_estimators": [3, 10, 30, 50, 100, 150, 300], "model__max_features": [2,4,6,8]
    }
]

这是最后一步

代码语言:javascript
复制
grid_search = GridSearchCV(pipe, param_grid, cv = cv1_splits, 
    scoring = "neg_mean_squared_error", 
    return_train_score = True)

clf = grid_search.fit(X, y)

generalization_error = cross_val_score(clf.best_estimator_, X = X, y = y, cv = cv2_splits)
generalization_error

现在,出现了故障(前面代码片段的下面两行):

如果我遵循scikit学习说明(链接 ),我应该写:

代码语言:javascript
复制
generalization_error = cross_val_score(clf, X = X, y = y, cv = cv2_splits, scoring = "neg_mean_squared_error")
    generalization_error

不幸的是调用cross_val_score(clf,X=X.)给出一个错误(索引超出了训练/测试分裂的),泛化错误数组只包含NaNs。

另一方面,如果我这样写的话:

代码语言:javascript
复制
generalization_error = cross_val_score(clf.best_estimator_, X = X, y = y, cv = cv2_splits, scoring = "neg_mean_squared_error")
        generalization_error

脚本运行得完美无缺,我能够看到包含分数的泛化错误数组。我能坚持做最后一种方式吗,还是整个过程有问题?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-11-10 08:07:19

对我来说,这里的问题可能是使用cv1_splitscv2_splits,而不是cv1_5cv2_5 (特别是使用cv1_splits造成问题)。

通常,cross_val_score()clf估计器的克隆上调用fit();在这种情况下,它是一个GridSearchCV估计器,可以安装到几个X_inner_train集上(X按照cv1_splits分层的子集,X的相同维数-参见这里中的符号)。作为由X构建的cv1_splits,它包含的索引与X维一致,但可能与X_inner_train维数不一致。

相反,通过将cv1_5传递给GridSearchCV估计器,估计器本身负责协调地分割内部训练集(参考参见这里 )。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/64678567

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档