Lab 6 - Spicy

Fancy Functions

Affiliation: Cal Poly, San Luis Obispo

Published: May 12, 2025

Modified: June 6, 2025

This assignment will challenge your function-writing abilities. I’m not going to lie, these functions are difficult but well within your reach. I do, however, want to recognize that not everyone is interested in being a “virtuoso” with their function writing. So, there are two options for this week’s lab:

Setting the Stage

My number one use case for writing functions and iteration/looping is to run the same exploration or modeling step repeatedly on different “tweaked” versions of an analysis. For example, our broad goal might be to fit a linear regression model to our data. In practice, however, we often have to make several choices along the way:

  • Keep missing values or fill them in (imputation)?
  • Filter out outliers in one or more variables?

We can map these choices to arguments in a custom model-fitting function:

  • impute: TRUE or FALSE
  • remove_outliers: TRUE or FALSE

A function that implements the analysis while allowing these choices to vary might look like this:

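fit_model <- function(df, impute, remove_outliers, mod) {
  # Placeholder: stands in for an imputation helper (see Exercise 2)
  if (impute) {
    df <- some_imputation_function(df)
  }

  # Placeholder: stands in for an outlier-removal helper (see Exercise 1)
  if (remove_outliers) {
    df <- function_for_removing_outliers(df)
  }

  # Fit the linear model on the (possibly modified) data
  lm(mod, data = df)
}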

Helper Functions

Exercise 1: Write a function that removes outliers from a dataset. The user should be able to supply the dataset, the variables to remove outliers from, and a threshold on the number of standard deviations (SDs) from the mean that defines an outlier. Hint 1: You will need to calculate a z-score to filter the values! Hint 2: Consider specifying a default value (e.g., 3) for sd_thresh.
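To make Hint 1 concrete, here is a minimal base-R sketch of the z-score filter on a small made-up vector (the values are hypothetical, chosen only for illustration). A value's z-score is its distance from the mean measured in SDs, and values whose |z| exceeds the threshold are dropped.

```r
# Made-up data: one value (50) sits far from the rest
x <- c(10, 11, 9, 10, 12, 50)

# z-score: distance from the mean, measured in standard deviations
z <- (x - mean(x)) / sd(x)

# Keep only values within 2 SDs of the mean
kept <- x[abs(z) <= 2]
kept  # c(10, 11, 9, 10, 12)
```

With the default threshold of 3, the 50 would survive here: in a sample of n values no z-score can exceed (n - 1)/√n, about 2.04 for n = 6, because the extreme value itself inflates the SD. On a dataset the size of diamonds this bound is not a practical concern.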

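library(tidyverse)
# base::use() attaches only the named functions; it requires R >= 4.5.0
use("glue", "glue")
use("cli", "cli_warn")

remove_outliers <- function(data, ..., sd_thresh = 3) {
  data |>
    filter(
      # Keep a row only if it passes the check in every selected column
      if_all(
        c(...),
        \(col) {
          # Non-numeric columns have no z-scores: warn and keep every row
          if (!is.numeric(col)) {
            # TODO: use a native function instead of `deparse(substitute())`
            cli_warn(glue(
              "{deparse(substitute(col))} is not a numeric column and will be ignored."
            ))
            return(TRUE)
          }
          # |z| = |x - mean| / SD; a missing value gives NA here,
          # and filter() drops that row
          abs((col - mean(col, na.rm = TRUE)) / sd(col, na.rm = TRUE)) <= sd_thresh
        }
      )
    )
}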
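The solution leans on if_all() inside filter(): a row is kept only when the predicate returns TRUE for every selected column, and a row that yields NA is dropped. A minimal sketch on a made-up tibble (toy, a, b and label are hypothetical names):

```r
library(dplyr)

# Made-up data: row 3 fails on `a`, row 4 fails on `b`, row 5 has a missing `a`
toy <- tibble(
  a = c(1, 2, -1, 3, NA),
  b = c(5, 6, 7, -2, 4),
  label = c("v", "w", "x", "y", "z")
)

# if_all(): keep rows where the predicate is TRUE in every selected column
kept_if_all <- toy |> filter(if_all(c(a, b), \(col) col > 0))

# The same condition written out by hand
kept_manual <- toy |> filter(a > 0 & b > 0)

kept_if_all$label  # c("v", "w")
```

The NA case matters for remove_outliers(): a missing value produces an NA z-score, so that row is removed as well.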

Testing Your Function!

## Testing how your function handles multiple input variables
remove_outliers(diamonds, price, x, y, z)
# A tibble: 52,689 × 10
   carat cut       color clarity depth table price     x     y     z
   <dbl> <ord>     <ord> <ord>   <dbl> <dbl> <int> <dbl> <dbl> <dbl>
 1  0.23 Ideal     E     SI2      61.5    55   326  3.95  3.98  2.43
 2  0.21 Premium   E     SI1      59.8    61   326  3.89  3.84  2.31
 3  0.23 Good      E     VS1      56.9    65   327  4.05  4.07  2.31
 4  0.29 Premium   I     VS2      62.4    58   334  4.2   4.23  2.63
 5  0.31 Good      J     SI2      63.3    58   335  4.34  4.35  2.75
 6  0.24 Very Good J     VVS2     62.8    57   336  3.94  3.96  2.48
 7  0.24 Very Good I     VVS1     62.3    57   336  3.95  3.98  2.47
 8  0.26 Very Good H     SI1      61.9    55   337  4.07  4.11  2.53
 9  0.22 Fair      E     VS2      65.1    61   337  3.87  3.78  2.49
10  0.23 Very Good H     VS1      59.4    61   338  4     4.05  2.39
# ℹ 52,679 more rows
## Testing how your function handles an input that isn't numeric
remove_outliers(diamonds, price, color)
Warning: There was 1 warning in `filter()`.
ℹ In argument: `&...`.
Caused by warning:
! color is not a numeric column and will be ignored.
# A tibble: 52,734 × 10
   carat cut       color clarity depth table price     x     y     z
   <dbl> <ord>     <ord> <ord>   <dbl> <dbl> <int> <dbl> <dbl> <dbl>
 1  0.23 Ideal     E     SI2      61.5    55   326  3.95  3.98  2.43
 2  0.21 Premium   E     SI1      59.8    61   326  3.89  3.84  2.31
 3  0.23 Good      E     VS1      56.9    65   327  4.05  4.07  2.31
 4  0.29 Premium   I     VS2      62.4    58   334  4.2   4.23  2.63
 5  0.31 Good      J     SI2      63.3    58   335  4.34  4.35  2.75
 6  0.24 Very Good J     VVS2     62.8    57   336  3.94  3.96  2.48
 7  0.24 Very Good I     VVS1     62.3    57   336  3.95  3.98  2.47
 8  0.26 Very Good H     SI1      61.9    55   337  4.07  4.11  2.53
 9  0.22 Fair      E     VS2      65.1    61   337  3.87  3.78  2.49
10  0.23 Very Good H     VS1      59.4    61   338  4     4.05  2.39
# ℹ 52,724 more rows
## Testing how your function handles a non-default sd_thresh
remove_outliers(diamonds, price, x, y, z, sd_thresh = 2)
# A tibble: 50,099 × 10
   carat cut       color clarity depth table price     x     y     z
   <dbl> <ord>     <ord> <ord>   <dbl> <dbl> <int> <dbl> <dbl> <dbl>
 1  0.23 Ideal     E     SI2      61.5    55   326  3.95  3.98  2.43
 2  0.21 Premium   E     SI1      59.8    61   326  3.89  3.84  2.31
 3  0.23 Good      E     VS1      56.9    65   327  4.05  4.07  2.31
 4  0.29 Premium   I     VS2      62.4    58   334  4.2   4.23  2.63
 5  0.31 Good      J     SI2      63.3    58   335  4.34  4.35  2.75
 6  0.24 Very Good J     VVS2     62.8    57   336  3.94  3.96  2.48
 7  0.24 Very Good I     VVS1     62.3    57   336  3.95  3.98  2.47
 8  0.26 Very Good H     SI1      61.9    55   337  4.07  4.11  2.53
 9  0.22 Fair      E     VS2      65.1    61   337  3.87  3.78  2.49
10  0.23 Very Good H     VS1      59.4    61   338  4     4.05  2.39
# ℹ 50,089 more rows

Exercise 2: Write a function that imputes missing values for numeric variables in a dataset. The user should be able to supply the dataset, the variables to impute values for, and a function to use when imputing. Hint 1: You will need to use across() to apply your function, since the user can input multiple variables. Hint 2: The replace_na() function is helpful here!
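Putting the two hints together: across() applies one function to several columns, and replace_na() fills the NAs in each column with a single value. A minimal sketch of mean imputation on a made-up tibble (toy, x, y and g are hypothetical names):

```r
library(dplyr)
library(tidyr)

# Made-up data with one missing value in each numeric column
toy <- tibble(x = c(1, NA, 3), y = c(10, 20, NA), g = c("a", "b", "c"))

imputed <- toy |>
  mutate(across(
    c(x, y),
    # Replace each column's NAs with that column's own mean
    \(col) replace_na(col, mean(col, na.rm = TRUE))
  ))

imputed$x  # c(1, 2, 3)
imputed$y  # c(10, 20, 15)
```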

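impute_missing <- function(data, ..., impute_fun = mean) {
  data |>
    mutate(across(c(...), \(col) {
      # Skip non-numeric columns with a warning and return them unchanged
      if (!is.numeric(col)) {
        cli_warn(glue(
          "{cur_column()} is not a numeric column and will be ignored."
        ))
        return(col)
      }
      # impute_fun must accept `na.rm`. replace_na() keeps the column's type,
      # so a non-whole-number mean can raise an error for an integer column.
      replace_na(col, impute_fun(col, na.rm = TRUE))
    }))
}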
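Because impute_missing() calls impute_fun(col, na.rm = TRUE), any function that accepts an na.rm argument can be plugged in, not just mean or median. A sketch with a hypothetical first_quartile() helper, shown on a plain vector so it runs on its own:

```r
library(tidyr)

# Hypothetical helper: the 25th percentile, with the na.rm argument
# that impute_missing() passes to impute_fun
first_quartile <- function(x, na.rm = FALSE) {
  quantile(x, probs = 0.25, na.rm = na.rm, names = FALSE)
}

x <- c(4, NA, 8, 2, 6)
filled <- replace_na(x, first_quartile(x, na.rm = TRUE))
filled  # c(4, 3.5, 8, 2, 6)
```

With the lab's function this would be called as impute_missing(nycflights13::flights, arr_delay, impute_fun = first_quartile). A function without an na.rm argument, such as \(x) min(x), would instead fail with an "unused argument" error.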

Testing Your Function!

## Testing how your function handles multiple input variables
impute_missing(nycflights13::flights, arr_delay, dep_delay)
# A tibble: 336,776 × 19
    year month   day dep_time sched_dep_time dep_delay arr_time sched_arr_time
   <int> <int> <int>    <int>          <int>     <dbl>    <int>          <int>
 1  2013     1     1      517            515         2      830            819
 2  2013     1     1      533            529         4      850            830
 3  2013     1     1      542            540         2      923            850
 4  2013     1     1      544            545        -1     1004           1022
 5  2013     1     1      554            600        -6      812            837
 6  2013     1     1      554            558        -4      740            728
 7  2013     1     1      555            600        -5      913            854
 8  2013     1     1      557            600        -3      709            723
 9  2013     1     1      557            600        -3      838            846
10  2013     1     1      558            600        -2      753            745
# ℹ 336,766 more rows
# ℹ 11 more variables: arr_delay <dbl>, carrier <chr>, flight <int>,
#   tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>,
#   hour <dbl>, minute <dbl>, time_hour <dttm>
## Testing how your function handles an input that isn't numeric
impute_missing(nycflights13::flights, arr_delay, carrier)
Warning: There was 1 warning in `mutate()`.
ℹ In argument: `across(...)`.
Caused by warning:
! carrier is not a numeric column and will be ignored.
# A tibble: 336,776 × 19
    year month   day dep_time sched_dep_time dep_delay arr_time sched_arr_time
   <int> <int> <int>    <int>          <int>     <dbl>    <int>          <int>
 1  2013     1     1      517            515         2      830            819
 2  2013     1     1      533            529         4      850            830
 3  2013     1     1      542            540         2      923            850
 4  2013     1     1      544            545        -1     1004           1022
 5  2013     1     1      554            600        -6      812            837
 6  2013     1     1      554            558        -4      740            728
 7  2013     1     1      555            600        -5      913            854
 8  2013     1     1      557            600        -3      709            723
 9  2013     1     1      557            600        -3      838            846
10  2013     1     1      558            600        -2      753            745
# ℹ 336,766 more rows
# ℹ 11 more variables: arr_delay <dbl>, carrier <chr>, flight <int>,
#   tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>,
#   hour <dbl>, minute <dbl>, time_hour <dttm>
## Testing how your function handles a non-default impute_fun
impute_missing(nycflights13::flights, arr_delay, dep_delay, impute_fun = median)
# A tibble: 336,776 × 19
    year month   day dep_time sched_dep_time dep_delay arr_time sched_arr_time
   <int> <int> <int>    <int>          <int>     <dbl>    <int>          <int>
 1  2013     1     1      517            515         2      830            819
 2  2013     1     1      533            529         4      850            830
 3  2013     1     1      542            540         2      923            850
 4  2013     1     1      544            545        -1     1004           1022
 5  2013     1     1      554            600        -6      812            837
 6  2013     1     1      554            558        -4      740            728
 7  2013     1     1      555            600        -5      913            854
 8  2013     1     1      557            600        -3      709            723
 9  2013     1     1      557            600        -3      838            846
10  2013     1     1      558            600        -2      753            745
# ℹ 336,766 more rows
# ℹ 11 more variables: arr_delay <dbl>, carrier <chr>, flight <int>,
#   tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>,
#   hour <dbl>, minute <dbl>, time_hour <dttm>

Primary Function

Exercise 3: Write a fit_model() function that fits a specified linear regression model for a specified dataset. The function should:

  • allow the user to specify if outliers should be removed (TRUE or FALSE)
  • allow the user to specify if missing observations should be imputed (TRUE or FALSE)

If either option is TRUE, your function should call the corresponding helper, remove_outliers() or impute_missing(), to modify the data before the regression model is fit.

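fit_model <- function(data, mod_formula, remove_outliers, impute_missing, ...) {
  # The logical arguments reuse the helpers' names; R still calls the helper
  # functions because call lookup skips non-function values
  data |>
    (\(data) {
      # `...` (the columns to process) is forwarded to both helpers.
      # Outliers are removed first; remove_outliers() drops rows that are NA
      # in those columns, so when both options are TRUE, imputation finds no
      # NAs left to fill there
      if (remove_outliers) data <- remove_outliers(data, ...)
      if (impute_missing) data <- impute_missing(data, ...)
      data
    })() |>
    lm(mod_formula, data = _)
}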
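The solution uses the names remove_outliers and impute_missing both for logical arguments and for the helper functions. This works because, when R evaluates a call such as remove_outliers(data, ...), it skips bindings that are not functions while searching for that name. A minimal sketch with a hypothetical double_it() helper:

```r
# Hypothetical global helper
double_it <- function(x) x * 2

# `double_it` is also a logical argument here, mirroring fit_model()
wrapper <- function(x, double_it = TRUE) {
  # In call position, R skips the logical `double_it` and finds the function
  if (double_it) x <- double_it(x)
  x
}

wrapper(5)                     # 10
wrapper(5, double_it = FALSE)  # 5
```

Distinct argument names would avoid relying on this rule, but the column names in the crossing() grid of Exercise 4 would then have to change too, since pmap() matches columns to arguments by name.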

Testing Your Function!

fit_model(
  diamonds,
  mod_formula = price ~ carat + cut,
  remove_outliers = TRUE,
  impute_missing = TRUE,
  price,
  carat
)

Call:
lm(formula = mod_formula, data = (function(data) {
    if (remove_outliers) 
        data <- remove_outliers(data, ...)
    if (impute_missing) 
        data <- impute_missing(data, ...)
    data
})(data))

Coefficients:
(Intercept)        carat        cut.L        cut.Q        cut.C        cut^4  
   -2460.16      7526.96      1059.65      -410.54       295.80        82.62  

Iteration

In the diamonds dataset, we want to understand the relationship between price and size (carat). We want to explore variation along three choices:

  1. The variables included in the model. We’ll explore 4 sets of variables:
      • No further variables (just price and carat)
      • Adjusting for cut
      • Adjusting for cut and clarity
      • Adjusting for cut, clarity, and color
  2. Whether or not to impute missing values.
  3. Whether or not to remove outliers in the carat variable (we’ll define outliers as cases whose carat is more than 3 SDs away from the mean).

Parameters

First, we need to define the set of parameters over which to iterate fit_model(). The tidyr package provides crossing(), which is handy for generating argument combinations: for each argument, we specify all of its possible values, and crossing() generates every combination. Note that you can create a list of formula objects in R with c(y ~ x1, y ~ x1 + x2).

df_arg_combos <- crossing(
  impute = c(TRUE, FALSE),
  remove_outliers = c(TRUE, FALSE),
  mod = c(y ~ x1, y ~ x1 + x2)
)
df_arg_combos
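crossing() also de-duplicates and sorts each input before combining, whereas tidyr’s expand_grid() keeps the values exactly as given. A small sketch (the flag column is a hypothetical name):

```r
library(tidyr)

# crossing(): unique values, sorted (FALSE before TRUE)
crossing(flag = c(TRUE, FALSE, TRUE))$flag
# returns c(FALSE, TRUE)

# expand_grid(): values kept in the order supplied, duplicates included
expand_grid(flag = c(TRUE, FALSE, TRUE))$flag
# returns c(TRUE, FALSE, TRUE)
```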

Exercise 4: Use crossing() to create the data frame of argument combinations for our analyses.

model_args <- crossing(
  mod_formula = c(
    price ~ carat,
    price ~ carat + cut,
    price ~ carat + cut + clarity,
    price ~ carat + cut + clarity + color
  ),
  remove_outliers = c(TRUE, FALSE),
  impute_missing = c(TRUE, FALSE)
)
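As a quick check on the grid’s size and row order, here is a sketch with placeholder formulas (y, a, b, c and d are hypothetical names; crossing() does not evaluate them). Four formulas times two values for each logical argument gives 4 × 2 × 2 = 16 rows. The first column varies slowest, and within each formula the logical columns run FALSE before TRUE.

```r
library(tidyr)

grid <- crossing(
  mod_formula = c(y ~ a, y ~ a + b, y ~ a + b + c, y ~ a + b + c + d),
  remove_outliers = c(TRUE, FALSE),
  impute_missing = c(TRUE, FALSE)
)

nrow(grid)                 # 16
grid$remove_outliers[1:4]  # FALSE FALSE TRUE TRUE
grid$impute_missing[1:4]   # FALSE TRUE FALSE TRUE
```

The same ordering explains the output of Exercise 5: consecutive results differ only in impute_missing, and because diamonds has no missing values, each such pair of fits (for example [[1]] and [[2]]) is identical.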

Iterating Over the Parameters

We’ve arrived at the final step!

Exercise 5: Use pmap() from purrr to apply the fit_model() function to the diamonds data for every combination of arguments.
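pmap() walks over the rows of a data frame, passing each column to the argument of the same name, and forwards any extra arguments unchanged to every call. A minimal sketch with a hypothetical add_scaled() function:

```r
library(purrr)
library(tibble)

# Each row supplies x and y; `scale` is the same for every call
args <- tibble(x = c(1, 2, 3), y = c(10, 20, 30))
add_scaled <- function(x, y, scale) (x + y) * scale

result <- pmap_dbl(args, add_scaled, scale = 2)
result  # c(22, 44, 66)
```

In the exercise, model_args plays the role of args, while data = diamonds and carat are the extra arguments forwarded to every fit_model() call. Plain pmap() returns a list, which suits lm objects; pmap_dbl() is used here only because the toy function returns a single number.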

# `carat` is forwarded through fit_model()'s `...` to the helper functions
pmap(model_args, fit_model, data = diamonds, carat)
[[1]]

Call:
lm(formula = mod_formula, data = (function(data) {
    if (remove_outliers) 
        data <- remove_outliers(data, ...)
    if (impute_missing) 
        data <- impute_missing(data, ...)
    data
})(data))

Coefficients:
(Intercept)        carat  
      -2256         7756  


[[2]]

Call:
lm(formula = mod_formula, data = (function(data) {
    if (remove_outliers) 
        data <- remove_outliers(data, ...)
    if (impute_missing) 
        data <- impute_missing(data, ...)
    data
})(data))

Coefficients:
(Intercept)        carat  
      -2256         7756  


[[3]]

Call:
lm(formula = mod_formula, data = (function(data) {
    if (remove_outliers) 
        data <- remove_outliers(data, ...)
    if (impute_missing) 
        data <- impute_missing(data, ...)
    data
})(data))

Coefficients:
(Intercept)        carat  
      -2354         7898  


[[4]]

Call:
lm(formula = mod_formula, data = (function(data) {
    if (remove_outliers) 
        data <- remove_outliers(data, ...)
    if (impute_missing) 
        data <- impute_missing(data, ...)
    data
})(data))

Coefficients:
(Intercept)        carat  
      -2354         7898  


[[5]]

Call:
lm(formula = mod_formula, data = (function(data) {
    if (remove_outliers) 
        data <- remove_outliers(data, ...)
    if (impute_missing) 
        data <- impute_missing(data, ...)
    data
})(data))

Coefficients:
(Intercept)        carat        cut.L        cut.Q        cut.C        cut^4  
   -2701.38      7871.08      1239.80      -528.60       367.91        74.59  


[[6]]

Call:
lm(formula = mod_formula, data = (function(data) {
    if (remove_outliers) 
        data <- remove_outliers(data, ...)
    if (impute_missing) 
        data <- impute_missing(data, ...)
    data
})(data))

Coefficients:
(Intercept)        carat        cut.L        cut.Q        cut.C        cut^4  
   -2701.38      7871.08      1239.80      -528.60       367.91        74.59  


[[7]]

Call:
lm(formula = mod_formula, data = (function(data) {
    if (remove_outliers) 
        data <- remove_outliers(data, ...)
    if (impute_missing) 
        data <- impute_missing(data, ...)
    data
})(data))

Coefficients:
(Intercept)        carat        cut.L        cut.Q        cut.C        cut^4  
   -2782.90      8012.63      1179.04      -465.31       337.40        90.91  


[[8]]

Call:
lm(formula = mod_formula, data = (function(data) {
    if (remove_outliers) 
        data <- remove_outliers(data, ...)
    if (impute_missing) 
        data <- impute_missing(data, ...)
    data
})(data))

Coefficients:
(Intercept)        carat        cut.L        cut.Q        cut.C        cut^4  
   -2782.90      8012.63      1179.04      -465.31       337.40        90.91  


[[9]]

Call:
lm(formula = mod_formula, data = (function(data) {
    if (remove_outliers) 
        data <- remove_outliers(data, ...)
    if (impute_missing) 
        data <- impute_missing(data, ...)
    data
})(data))

Coefficients:
(Intercept)        carat        cut.L        cut.Q        cut.C        cut^4  
  -3187.540     8472.026      713.804     -334.503      188.482        1.663  
  clarity.L    clarity.Q    clarity.C    clarity^4    clarity^5    clarity^6  
   4011.681    -1821.922      917.658     -430.047      257.141       26.909  
  clarity^7  
    186.742  


[[10]]

Call:
lm(formula = mod_formula, data = (function(data) {
    if (remove_outliers) 
        data <- remove_outliers(data, ...)
    if (impute_missing) 
        data <- impute_missing(data, ...)
    data
})(data))

Coefficients:
(Intercept)        carat        cut.L        cut.Q        cut.C        cut^4  
  -3187.540     8472.026      713.804     -334.503      188.482        1.663  
  clarity.L    clarity.Q    clarity.C    clarity^4    clarity^5    clarity^6  
   4011.681    -1821.922      917.658     -430.047      257.141       26.909  
  clarity^7  
    186.742  


[[11]]

Call:
lm(formula = mod_formula, data = (function(data) {
    if (remove_outliers) 
        data <- remove_outliers(data, ...)
    if (impute_missing) 
        data <- impute_missing(data, ...)
    data
})(data))

Coefficients:
(Intercept)        carat        cut.L        cut.Q        cut.C        cut^4  
  -3211.955     8604.166      693.601     -310.001      176.420        9.009  
  clarity.L    clarity.Q    clarity.C    clarity^4    clarity^5    clarity^6  
   3790.470    -1546.330      694.758     -294.293      178.926       56.096  
  clarity^7  
    181.933  


[[12]]

Call:
lm(formula = mod_formula, data = (function(data) {
    if (remove_outliers) 
        data <- remove_outliers(data, ...)
    if (impute_missing) 
        data <- impute_missing(data, ...)
    data
})(data))

Coefficients:
(Intercept)        carat        cut.L        cut.Q        cut.C        cut^4  
  -3211.955     8604.166      693.601     -310.001      176.420        9.009  
  clarity.L    clarity.Q    clarity.C    clarity^4    clarity^5    clarity^6  
   3790.470    -1546.330      694.758     -294.293      178.926       56.096  
  clarity^7  
    181.933  


[[13]]

Call:
lm(formula = mod_formula, data = (function(data) {
    if (remove_outliers) 
        data <- remove_outliers(data, ...)
    if (impute_missing) 
        data <- impute_missing(data, ...)
    data
})(data))

Coefficients:
(Intercept)        carat        cut.L        cut.Q        cut.C        cut^4  
  -3710.603     8886.129      698.907     -327.686      180.565       -1.207  
  clarity.L    clarity.Q    clarity.C    clarity^4    clarity^5    clarity^6  
   4217.535    -1832.406      923.273     -361.995      216.616        2.105  
  clarity^7      color.L      color.Q      color.C      color^4      color^5  
    110.340    -1910.288     -627.954     -171.960       21.678      -85.943  
    color^6  
    -49.986  


[[14]]

Call:
lm(formula = mod_formula, data = (function(data) {
    if (remove_outliers) 
        data <- remove_outliers(data, ...)
    if (impute_missing) 
        data <- impute_missing(data, ...)
    data
})(data))

Coefficients:
(Intercept)        carat        cut.L        cut.Q        cut.C        cut^4  
  -3710.603     8886.129      698.907     -327.686      180.565       -1.207  
  clarity.L    clarity.Q    clarity.C    clarity^4    clarity^5    clarity^6  
   4217.535    -1832.406      923.273     -361.995      216.616        2.105  
  clarity^7      color.L      color.Q      color.C      color^4      color^5  
    110.340    -1910.288     -627.954     -171.960       21.678      -85.943  
    color^6  
    -49.986  


[[15]]

Call:
lm(formula = mod_formula, data = (function(data) {
    if (remove_outliers) 
        data <- remove_outliers(data, ...)
    if (impute_missing) 
        data <- impute_missing(data, ...)
    data
})(data))

Coefficients:
(Intercept)        carat        cut.L        cut.Q        cut.C        cut^4  
  -3724.322     9007.286      675.568     -301.591      167.295        6.644  
  clarity.L    clarity.Q    clarity.C    clarity^4    clarity^5    clarity^6  
   3982.636    -1541.736      689.208     -219.693      135.996       32.560  
  clarity^7      color.L      color.Q      color.C      color^4      color^5  
    104.156    -1925.179     -613.256     -163.611       29.485      -72.600  
    color^6  
    -49.973  


[[16]]

Call:
lm(formula = mod_formula, data = (function(data) {
    if (remove_outliers) 
        data <- remove_outliers(data, ...)
    if (impute_missing) 
        data <- impute_missing(data, ...)
    data
})(data))

Coefficients:
(Intercept)        carat        cut.L        cut.Q        cut.C        cut^4  
  -3724.322     9007.286      675.568     -301.591      167.295        6.644  
  clarity.L    clarity.Q    clarity.C    clarity^4    clarity^5    clarity^6  
   3982.636    -1541.736      689.208     -219.693      135.996       32.560  
  clarity^7      color.L      color.Q      color.C      color^4      color^5  
    104.156    -1925.179     -613.256     -163.611       29.485      -72.600  
    color^6  
    -49.973
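A printed list of 16 models is hard to compare. One option, sketched here assuming a compact summary is wanted, is to keep each fit next to its arguments in a list-column and pull out the coefficient of interest. To stay self-contained, the sketch uses R’s built-in mtcars data and two formulas rather than diamonds and fit_model():

```r
library(dplyr)
library(purrr)
library(tibble)

models <- tibble(mod_formula = list(mpg ~ wt, mpg ~ wt + hp)) |>
  mutate(
    # One lm fit per formula, stored in a list-column
    model = map(mod_formula, \(f) lm(f, data = mtcars)),
    # Extract the slope on wt from each fit
    wt_coef = map_dbl(model, \(m) coef(m)[["wt"]])
  )

models$wt_coef
```

With the lab’s objects, a similar approach would attach the list returned by pmap() as a new column, for example model_args$model <- pmap(model_args, fit_model, data = diamonds, carat).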