# Sketch of an analysis function whose modelling choices are arguments:
# optionally impute, optionally drop outliers, then fit `mod` with lm().
# NOTE(review): `some_imputation_function()` and
# `function_for_removing_outliers()` are placeholders and are not defined in
# this section -- the exercises below build the real helpers.
fit_model <- function(df, impute, remove_outliers, mod) {
  if (impute) {
    df <- some_imputation_function(df)
  }
  if (remove_outliers) {
    df <- function_for_removing_outliers(df)
  }
  lm(mod, data = df)
}
This assignment will challenge your function writing abilities. I’m not going to lie, these functions are difficult but well within your reach. I do, however, want to recognize that not everyone is interested in being a “virtuoso” with their function writing. So, there are two options for this week’s lab:
- Option 1: Complete this lab assignment in search of virtuoso status with your function writing
- Option 2: Complete one of the difficult functions (Exercise 1 or Exercise 2) and complete the “Alternative Lab 6”.
Setting the Stage
My number one use case for writing functions and iteration / looping is to perform some exploration or modeling repeatedly for different “tweaked” versions of an analysis. For example, our broad goal might be to fit a linear regression model to our data. However, there are often multiple choices that we have to make in practice:
- Keep missing values or fill them in (imputation)?
- Filter out outliers in one or more variables?
We can map these choices to arguments in a custom model-fitting function:
- `impute`: `TRUE` or `FALSE`
- `remove_outliers`: `TRUE` or `FALSE`
A function that implements the analysis and allows for variation in these choices:
Helper Functions
Exercise 1: Write a function that removes outliers in a dataset. The user should be able to supply the dataset, the variables to remove outliers from, and a threshold on the number of SDs away from the mean used to define outliers. Hint 1: You will need to calculate a z-score, $z = (x - \bar{x}) / s$, to filter the values! Hint 2: You might want to consider specifying a default value (e.g., 3) for `sd_thresh`.
library(tidyverse)
use("glue", "glue")
use("cli", "cli_warn")
# Drop rows that contain an outlier in any of the selected numeric columns.
#
# A value is an outlier when its z-score -- computed from its own column's
# mean and SD, ignoring NAs -- exceeds `sd_thresh` in absolute value. A row is
# kept only if none of the selected numeric columns flags it.
#
# Arguments:
#   data       A data frame.
#   ...        <tidy-select> Columns to screen. Non-numeric columns raise a
#              warning and are ignored.
#   sd_thresh  Number of SDs from the mean beyond which a value counts as an
#              outlier. Defaults to 3.
# Returns `data` with outlying rows removed. Rows whose z-score cannot be
# computed (a missing value, or a column whose SD is zero or undefined) are
# kept: missingness is not evidence of an outlier, and keeping them lets a
# later imputation step still fill them in.
remove_outliers <- function(data, ..., sd_thresh = 3) {
  # Resolve the selection once, up front, so non-numeric columns can be
  # reported by name without depending on how dplyr evaluates the predicate
  # inside `if_all()` (which is what `deparse(substitute())` relied on).
  selected <- select(data, ...)
  is_num <- vapply(selected, is.numeric, logical(1))

  for (col in names(selected)[!is_num]) {
    cli_warn("{col} is not a numeric column and will be ignored.")
  }
  num_cols <- names(selected)[is_num]

  data |>
    filter(
      if_all(all_of(num_cols), \(x) {
        z <- (x - mean(x, na.rm = TRUE)) / sd(x, na.rm = TRUE)
        is.na(z) | abs(z) <= sd_thresh
      })
    )
}
Testing Your Function!
## Screening several numeric columns at once
diamonds |>
  remove_outliers(price, x, y, z)
# A tibble: 52,689 × 10
carat cut color clarity depth table price x y z
<dbl> <ord> <ord> <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl>
1 0.23 Ideal E SI2 61.5 55 326 3.95 3.98 2.43
2 0.21 Premium E SI1 59.8 61 326 3.89 3.84 2.31
3 0.23 Good E VS1 56.9 65 327 4.05 4.07 2.31
4 0.29 Premium I VS2 62.4 58 334 4.2 4.23 2.63
5 0.31 Good J SI2 63.3 58 335 4.34 4.35 2.75
6 0.24 Very Good J VVS2 62.8 57 336 3.94 3.96 2.48
7 0.24 Very Good I VVS1 62.3 57 336 3.95 3.98 2.47
8 0.26 Very Good H SI1 61.9 55 337 4.07 4.11 2.53
9 0.22 Fair E VS2 65.1 61 337 3.87 3.78 2.49
10 0.23 Very Good H VS1 59.4 61 338 4 4.05 2.39
# ℹ 52,679 more rows
## A non-numeric column (`color`) should warn and be skipped
diamonds |>
  remove_outliers(price, color)
Warning: There was 1 warning in `filter()`.
ℹ In argument: `&...`.
Caused by warning:
! color is not a numeric column and will be ignored.
# A tibble: 52,734 × 10
carat cut color clarity depth table price x y z
<dbl> <ord> <ord> <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl>
1 0.23 Ideal E SI2 61.5 55 326 3.95 3.98 2.43
2 0.21 Premium E SI1 59.8 61 326 3.89 3.84 2.31
3 0.23 Good E VS1 56.9 65 327 4.05 4.07 2.31
4 0.29 Premium I VS2 62.4 58 334 4.2 4.23 2.63
5 0.31 Good J SI2 63.3 58 335 4.34 4.35 2.75
6 0.24 Very Good J VVS2 62.8 57 336 3.94 3.96 2.48
7 0.24 Very Good I VVS1 62.3 57 336 3.95 3.98 2.47
8 0.26 Very Good H SI1 61.9 55 337 4.07 4.11 2.53
9 0.22 Fair E VS2 65.1 61 337 3.87 3.78 2.49
10 0.23 Very Good H VS1 59.4 61 338 4 4.05 2.39
# ℹ 52,724 more rows
## A stricter, non-default threshold of 2 SDs
diamonds |>
  remove_outliers(price, x, y, z, sd_thresh = 2)
# A tibble: 50,099 × 10
carat cut color clarity depth table price x y z
<dbl> <ord> <ord> <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl>
1 0.23 Ideal E SI2 61.5 55 326 3.95 3.98 2.43
2 0.21 Premium E SI1 59.8 61 326 3.89 3.84 2.31
3 0.23 Good E VS1 56.9 65 327 4.05 4.07 2.31
4 0.29 Premium I VS2 62.4 58 334 4.2 4.23 2.63
5 0.31 Good J SI2 63.3 58 335 4.34 4.35 2.75
6 0.24 Very Good J VVS2 62.8 57 336 3.94 3.96 2.48
7 0.24 Very Good I VVS1 62.3 57 336 3.95 3.98 2.47
8 0.26 Very Good H SI1 61.9 55 337 4.07 4.11 2.53
9 0.22 Fair E VS2 65.1 61 337 3.87 3.78 2.49
10 0.23 Very Good H VS1 59.4 61 338 4 4.05 2.39
# ℹ 50,089 more rows
Exercise 2: Write a function that imputes missing values for numeric variables in a dataset. The user should be able to supply the dataset, the variables to impute values for, and a function to use when imputing. Hint 1: You will need to use `across()` to apply your function, since the user can input multiple variables. Hint 2: The `replace_na()` function is helpful here!
# Impute missing values in the selected numeric columns.
#
# Each selected numeric column has its NAs replaced by
# `impute_fun(column, na.rm = TRUE)`. Non-numeric columns raise a warning and
# are returned unchanged.
#
# Arguments:
#   data        A data frame.
#   ...         <tidy-select> Columns to impute.
#   impute_fun  Summary function that produces the fill value. It is called
#               with `na.rm = TRUE`, so it must accept that argument (as
#               `mean()` and `median()` do). Defaults to `mean`.
# Returns `data` with the selected columns imputed. Columns without missing
# values are left untouched. An integer column that needs a non-integer fill
# value (e.g. a mean) is returned as double.
impute_missing <- function(data, ..., impute_fun = mean) {
  data |>
    mutate(across(c(...), \(col) {
      if (!is.numeric(col)) {
        col_name <- cur_column()
        cli_warn("{col_name} is not a numeric column and will be ignored.")
        return(col)
      }
      if (!anyNA(col)) {
        return(col)
      }
      fill <- impute_fun(col, na.rm = TRUE)
      # replace_na() casts `fill` to the column's type, so a fractional fill
      # into an integer column would fail as a lossy cast; widen first.
      if (is.integer(col) && !is.integer(fill)) {
        col <- as.double(col)
      }
      replace_na(col, fill)
    }))
}
Testing Your Function!
## Imputing several columns at once
nycflights13::flights |>
  impute_missing(arr_delay, dep_delay)
# A tibble: 336,776 × 19
year month day dep_time sched_dep_time dep_delay arr_time sched_arr_time
<int> <int> <int> <int> <int> <dbl> <int> <int>
1 2013 1 1 517 515 2 830 819
2 2013 1 1 533 529 4 850 830
3 2013 1 1 542 540 2 923 850
4 2013 1 1 544 545 -1 1004 1022
5 2013 1 1 554 600 -6 812 837
6 2013 1 1 554 558 -4 740 728
7 2013 1 1 555 600 -5 913 854
8 2013 1 1 557 600 -3 709 723
9 2013 1 1 557 600 -3 838 846
10 2013 1 1 558 600 -2 753 745
# ℹ 336,766 more rows
# ℹ 11 more variables: arr_delay <dbl>, carrier <chr>, flight <int>,
# tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>,
# hour <dbl>, minute <dbl>, time_hour <dttm>
## A non-numeric column (`carrier`) should warn and be left as is
nycflights13::flights |>
  impute_missing(arr_delay, carrier)
Warning: There was 1 warning in `mutate()`.
ℹ In argument: `across(...)`.
Caused by warning:
! carrier is not a numeric column and will be ignored.
# A tibble: 336,776 × 19
year month day dep_time sched_dep_time dep_delay arr_time sched_arr_time
<int> <int> <int> <int> <int> <dbl> <int> <int>
1 2013 1 1 517 515 2 830 819
2 2013 1 1 533 529 4 850 830
3 2013 1 1 542 540 2 923 850
4 2013 1 1 544 545 -1 1004 1022
5 2013 1 1 554 600 -6 812 837
6 2013 1 1 554 558 -4 740 728
7 2013 1 1 555 600 -5 913 854
8 2013 1 1 557 600 -3 709 723
9 2013 1 1 557 600 -3 838 846
10 2013 1 1 558 600 -2 753 745
# ℹ 336,766 more rows
# ℹ 11 more variables: arr_delay <dbl>, carrier <chr>, flight <int>,
# tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>,
# hour <dbl>, minute <dbl>, time_hour <dttm>
## A non-default imputation function (median instead of mean)
nycflights13::flights |>
  impute_missing(arr_delay, dep_delay, impute_fun = median)
# A tibble: 336,776 × 19
year month day dep_time sched_dep_time dep_delay arr_time sched_arr_time
<int> <int> <int> <int> <int> <dbl> <int> <int>
1 2013 1 1 517 515 2 830 819
2 2013 1 1 533 529 4 850 830
3 2013 1 1 542 540 2 923 850
4 2013 1 1 544 545 -1 1004 1022
5 2013 1 1 554 600 -6 812 837
6 2013 1 1 554 558 -4 740 728
7 2013 1 1 555 600 -5 913 854
8 2013 1 1 557 600 -3 709 723
9 2013 1 1 557 600 -3 838 846
10 2013 1 1 558 600 -2 753 745
# ℹ 336,766 more rows
# ℹ 11 more variables: arr_delay <dbl>, carrier <chr>, flight <int>,
# tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>,
# hour <dbl>, minute <dbl>, time_hour <dttm>
Primary Function
Exercise 3: Write a `fit_model()` function that fits a specified linear regression model for a specified dataset. The function should:

- allow the user to specify if outliers should be removed (`TRUE` or `FALSE`)
- allow the user to specify if missing observations should be imputed (`TRUE` or `FALSE`)

If either option is `TRUE`, your function should call your `remove_outliers()` or `impute_missing()` functions to modify the data before the regression model is fit.
# Fit a linear model after optionally cleaning the data.
#
# Outliers are removed first, then missing values are imputed. With this
# order, the imputation statistic is computed without the removed outliers.
#
# Arguments:
#   data             A data frame.
#   mod_formula      Model formula passed to lm().
#   remove_outliers  TRUE/FALSE: call remove_outliers() before fitting?
#   impute_missing   TRUE/FALSE: call impute_missing() before fitting?
#   ...              <tidy-select> Columns to screen/impute; forwarded to
#                    both helpers.
#   sd_thresh        Outlier threshold forwarded to remove_outliers().
#   impute_fun       Fill function forwarded to impute_missing().
# Returns the fitted `lm` object.
fit_model <- function(data, mod_formula, remove_outliers, impute_missing, ...,
                      sd_thresh = 3, impute_fun = mean) {
  # The flags share names with the helper functions. Calling
  # `remove_outliers(...)` still reaches the helper, because R's function
  # lookup skips bindings that are not functions.
  if (remove_outliers) {
    data <- remove_outliers(data, ..., sd_thresh = sd_thresh)
  }
  if (impute_missing) {
    data <- impute_missing(data, ..., impute_fun = impute_fun)
  }
  # Fit on a named object so the printed Call stays short and readable.
  lm(mod_formula, data = data)
}
Testing Your Function!
# Drop `price`/`carat` outliers, impute those columns, then fit the model.
diamonds |>
  fit_model(
    mod_formula = price ~ carat + cut,
    remove_outliers = TRUE,
    impute_missing = TRUE,
    price,
    carat
  )
Call:
lm(formula = mod_formula, data = (function(data) {
if (remove_outliers)
data <- remove_outliers(data, ...)
if (impute_missing)
data <- impute_missing(data, ...)
data
})(data))
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-2460.16 7526.96 1059.65 -410.54 295.80 82.62
Iteration
In the `diamonds` dataset, we want to understand the relationship between `price` and size (`carat`). We want to explore variation along three choices:

1. The variables included in the model. We’ll explore 4 sets of variables:
   - No further variables (just `price` and `carat`)
   - Adjusting for `cut`
   - Adjusting for `cut` and `clarity`
   - Adjusting for `cut`, `clarity`, and `color`
2. Whether or not to impute missing values.
3. Whether or not to remove outliers in the `carat` variable (we’ll define outliers as cases whose `carat` is over 3 SDs away from the mean).
Parameters
First, we need to define the set of parameters we want to iterate the `fit_model()` function over. The `tidyr` package has a function called `crossing()` that is handy for generating argument combinations. For each argument, we specify all possible values for that argument and `crossing()` generates all combinations. Note that you can create a list of formula objects in R with `c(y ~ x1, y ~ x1 + x2)`.
# Illustration: crossing() returns one row per combination of the supplied
# argument values (here 2 x 2 x 2 = 8 rows).
df_arg_combos <- crossing(
  impute = c(TRUE, FALSE),
  remove_outliers = c(TRUE, FALSE),
  mod = c(y ~ x1, y ~ x1 + x2)
)
df_arg_combos
Exercise 4: Use `crossing()` to create the data frame of argument combinations for our analyses.
# Every combination of model formula and cleaning choice:
# 4 formulas x 2 outlier settings x 2 imputation settings = 16 rows.
# Column names match fit_model()'s argument names so pmap() can pass them
# by name.
model_args <- crossing(
  mod_formula = c(
    price ~ carat,
    price ~ carat + cut,
    price ~ carat + cut + clarity,
    price ~ carat + cut + clarity + color
  ),
  remove_outliers = c(TRUE, FALSE),
  impute_missing = c(TRUE, FALSE)
)
Iterating Over the Parameters
We’ve arrived at the final step!
Exercise 5: Use `pmap()` from `purrr` to apply the `fit_model()` function to the `diamonds` data for every combination of arguments in `model_args`.
# One fitted model per row of `model_args`, always on `diamonds`, with
# `carat` as the column screened for outliers / imputed.
model_args |>
  pmap(\(mod_formula, remove_outliers, impute_missing) {
    fit_model(
      data = diamonds,
      mod_formula = mod_formula,
      remove_outliers = remove_outliers,
      impute_missing = impute_missing,
      carat
    )
  })
[[1]]
Call:
lm(formula = mod_formula, data = (function(data) {
if (remove_outliers)
data <- remove_outliers(data, ...)
if (impute_missing)
data <- impute_missing(data, ...)
data
})(data))
Coefficients:
(Intercept) carat
-2256 7756
[[2]]
Call:
lm(formula = mod_formula, data = (function(data) {
if (remove_outliers)
data <- remove_outliers(data, ...)
if (impute_missing)
data <- impute_missing(data, ...)
data
})(data))
Coefficients:
(Intercept) carat
-2256 7756
[[3]]
Call:
lm(formula = mod_formula, data = (function(data) {
if (remove_outliers)
data <- remove_outliers(data, ...)
if (impute_missing)
data <- impute_missing(data, ...)
data
})(data))
Coefficients:
(Intercept) carat
-2354 7898
[[4]]
Call:
lm(formula = mod_formula, data = (function(data) {
if (remove_outliers)
data <- remove_outliers(data, ...)
if (impute_missing)
data <- impute_missing(data, ...)
data
})(data))
Coefficients:
(Intercept) carat
-2354 7898
[[5]]
Call:
lm(formula = mod_formula, data = (function(data) {
if (remove_outliers)
data <- remove_outliers(data, ...)
if (impute_missing)
data <- impute_missing(data, ...)
data
})(data))
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-2701.38 7871.08 1239.80 -528.60 367.91 74.59
[[6]]
Call:
lm(formula = mod_formula, data = (function(data) {
if (remove_outliers)
data <- remove_outliers(data, ...)
if (impute_missing)
data <- impute_missing(data, ...)
data
})(data))
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-2701.38 7871.08 1239.80 -528.60 367.91 74.59
[[7]]
Call:
lm(formula = mod_formula, data = (function(data) {
if (remove_outliers)
data <- remove_outliers(data, ...)
if (impute_missing)
data <- impute_missing(data, ...)
data
})(data))
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-2782.90 8012.63 1179.04 -465.31 337.40 90.91
[[8]]
Call:
lm(formula = mod_formula, data = (function(data) {
if (remove_outliers)
data <- remove_outliers(data, ...)
if (impute_missing)
data <- impute_missing(data, ...)
data
})(data))
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-2782.90 8012.63 1179.04 -465.31 337.40 90.91
[[9]]
Call:
lm(formula = mod_formula, data = (function(data) {
if (remove_outliers)
data <- remove_outliers(data, ...)
if (impute_missing)
data <- impute_missing(data, ...)
data
})(data))
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-3187.540 8472.026 713.804 -334.503 188.482 1.663
clarity.L clarity.Q clarity.C clarity^4 clarity^5 clarity^6
4011.681 -1821.922 917.658 -430.047 257.141 26.909
clarity^7
186.742
[[10]]
Call:
lm(formula = mod_formula, data = (function(data) {
if (remove_outliers)
data <- remove_outliers(data, ...)
if (impute_missing)
data <- impute_missing(data, ...)
data
})(data))
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-3187.540 8472.026 713.804 -334.503 188.482 1.663
clarity.L clarity.Q clarity.C clarity^4 clarity^5 clarity^6
4011.681 -1821.922 917.658 -430.047 257.141 26.909
clarity^7
186.742
[[11]]
Call:
lm(formula = mod_formula, data = (function(data) {
if (remove_outliers)
data <- remove_outliers(data, ...)
if (impute_missing)
data <- impute_missing(data, ...)
data
})(data))
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-3211.955 8604.166 693.601 -310.001 176.420 9.009
clarity.L clarity.Q clarity.C clarity^4 clarity^5 clarity^6
3790.470 -1546.330 694.758 -294.293 178.926 56.096
clarity^7
181.933
[[12]]
Call:
lm(formula = mod_formula, data = (function(data) {
if (remove_outliers)
data <- remove_outliers(data, ...)
if (impute_missing)
data <- impute_missing(data, ...)
data
})(data))
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-3211.955 8604.166 693.601 -310.001 176.420 9.009
clarity.L clarity.Q clarity.C clarity^4 clarity^5 clarity^6
3790.470 -1546.330 694.758 -294.293 178.926 56.096
clarity^7
181.933
[[13]]
Call:
lm(formula = mod_formula, data = (function(data) {
if (remove_outliers)
data <- remove_outliers(data, ...)
if (impute_missing)
data <- impute_missing(data, ...)
data
})(data))
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-3710.603 8886.129 698.907 -327.686 180.565 -1.207
clarity.L clarity.Q clarity.C clarity^4 clarity^5 clarity^6
4217.535 -1832.406 923.273 -361.995 216.616 2.105
clarity^7 color.L color.Q color.C color^4 color^5
110.340 -1910.288 -627.954 -171.960 21.678 -85.943
color^6
-49.986
[[14]]
Call:
lm(formula = mod_formula, data = (function(data) {
if (remove_outliers)
data <- remove_outliers(data, ...)
if (impute_missing)
data <- impute_missing(data, ...)
data
})(data))
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-3710.603 8886.129 698.907 -327.686 180.565 -1.207
clarity.L clarity.Q clarity.C clarity^4 clarity^5 clarity^6
4217.535 -1832.406 923.273 -361.995 216.616 2.105
clarity^7 color.L color.Q color.C color^4 color^5
110.340 -1910.288 -627.954 -171.960 21.678 -85.943
color^6
-49.986
[[15]]
Call:
lm(formula = mod_formula, data = (function(data) {
if (remove_outliers)
data <- remove_outliers(data, ...)
if (impute_missing)
data <- impute_missing(data, ...)
data
})(data))
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-3724.322 9007.286 675.568 -301.591 167.295 6.644
clarity.L clarity.Q clarity.C clarity^4 clarity^5 clarity^6
3982.636 -1541.736 689.208 -219.693 135.996 32.560
clarity^7 color.L color.Q color.C color^4 color^5
104.156 -1925.179 -613.256 -163.611 29.485 -72.600
color^6
-49.973
[[16]]
Call:
lm(formula = mod_formula, data = (function(data) {
if (remove_outliers)
data <- remove_outliers(data, ...)
if (impute_missing)
data <- impute_missing(data, ...)
data
})(data))
Coefficients:
(Intercept) carat cut.L cut.Q cut.C cut^4
-3724.322 9007.286 675.568 -301.591 167.295 6.644
clarity.L clarity.Q clarity.C clarity^4 clarity^5 clarity^6
3982.636 -1541.736 689.208 -219.693 135.996 32.560
clarity^7 color.L color.Q color.C color^4 color^5
104.156 -1925.179 -613.256 -163.611 29.485 -72.600
color^6
-49.973