Mental Health and Musical Preferences

Author

Meghan Harris

Introduction

I haven’t done statistical modeling since 2021, and back then I tried using base R’s lm() and hated it. I didn’t know tidymodels existed at the time, so I assumed modeling just wasn’t for me. Fast forward to 2024: I’m pivoting from data science to open source development, I’m learning Python alongside R, and I’ve decided to revisit modeling with modern tools. This notebook compares a simple ML workflow in both R (tidymodels) and Python (scikit-learn) to:

  1. Refresh my modeling skills using tidymodels (which actually makes sense to my brain)
  2. Practice Python by replicating a workflow I understand in R

I’m using the Music & Mental Health Survey Results dataset from Kaggle to predict anxiety scores from music listening habits.

Data Import and Viewing

First, we import the data.

Code
# install.packages(c("janitor", "tidyverse", "tidymodels"))
library(tidymodels)
library(tidyverse)
library(reticulate)

# Grab relative data path
data_path <-
  here::here(
    "mental_health_music",
    "data",
    "mxmh_survey_results.csv"
  )

# Import data and clean vars
df_mh_music_r <-
  read_csv(data_path) |>
  janitor::clean_names()
Code
import polars as pl
import janitor.polars
from pathlib import Path
from plotnine import (
    ggplot,
    aes,
    after_stat,
    element_blank,
    element_text,
    geom_abline,
    geom_bar,
    geom_point,
    geom_text,
    labs,
    scale_fill_gradient2,
    scale_x_discrete,
    theme,
    theme_minimal,
)
import textwrap

# Python has been SUPER annoying with paths for me in quarto docs
# Maybe it's me/my machine but, meh
# Setting two possible locations for the data file
possible_paths = [
    Path("data") / "mxmh_survey_results.csv",
    Path("mental_health_music") / "data" / "mxmh_survey_results.csv",
]

for path in possible_paths:
    if path.exists():
        data_path = path
        break
else:
    raise FileNotFoundError("mxmh_survey_results.csv not found in expected locations.")

# Read in csv and clean up the names
df_mh_music_py = pl.read_csv(
    data_path,
    # This is needed because there's decimals in the
    # scoring cols where integer was inferred
    infer_schema_length=1000,
).clean_names(remove_special=True)

Then, do a little skimming with glimpse…

Code
# Take a peek
df_mh_music_r |>
  glimpse()
Rows: 736
Columns: 33
$ timestamp                  <chr> "8/27/2022 19:29:02", "8/27/2022 19:57:31",…
$ age                        <dbl> 18, 63, 18, 61, 18, 18, 18, 21, 19, 18, 18,…
$ primary_streaming_service  <chr> "Spotify", "Pandora", "Spotify", "YouTube M…
$ hours_per_day              <dbl> 3.0, 1.5, 4.0, 2.5, 4.0, 5.0, 3.0, 1.0, 6.0…
$ while_working              <chr> "Yes", "Yes", "No", "Yes", "Yes", "Yes", "Y…
$ instrumentalist            <chr> "Yes", "No", "No", "No", "No", "Yes", "Yes"…
$ composer                   <chr> "Yes", "No", "No", "Yes", "No", "Yes", "No"…
$ fav_genre                  <chr> "Latin", "Rock", "Video game music", "Jazz"…
$ exploratory                <chr> "Yes", "Yes", "No", "Yes", "Yes", "Yes", "Y…
$ foreign_languages          <chr> "Yes", "No", "Yes", "Yes", "No", "Yes", "Ye…
$ bpm                        <dbl> 156, 119, 132, 84, 107, 86, 66, 95, 94, 155…
$ frequency_classical        <chr> "Rarely", "Sometimes", "Never", "Sometimes"…
$ frequency_country          <chr> "Never", "Never", "Never", "Never", "Never"…
$ frequency_edm              <chr> "Rarely", "Never", "Very frequently", "Neve…
$ frequency_folk             <chr> "Never", "Rarely", "Never", "Rarely", "Neve…
$ frequency_gospel           <chr> "Never", "Sometimes", "Never", "Sometimes",…
$ frequency_hip_hop          <chr> "Sometimes", "Rarely", "Rarely", "Never", "…
$ frequency_jazz             <chr> "Never", "Very frequently", "Rarely", "Very…
$ frequency_k_pop            <chr> "Very frequently", "Rarely", "Very frequent…
$ frequency_latin            <chr> "Very frequently", "Sometimes", "Never", "V…
$ frequency_lofi             <chr> "Rarely", "Rarely", "Sometimes", "Sometimes…
$ frequency_metal            <chr> "Never", "Never", "Sometimes", "Never", "Ne…
$ frequency_pop              <chr> "Very frequently", "Sometimes", "Rarely", "…
$ frequency_r_b              <chr> "Sometimes", "Sometimes", "Never", "Sometim…
$ frequency_rap              <chr> "Very frequently", "Rarely", "Rarely", "Nev…
$ frequency_rock             <chr> "Never", "Very frequently", "Rarely", "Neve…
$ frequency_video_game_music <chr> "Sometimes", "Rarely", "Very frequently", "…
$ anxiety                    <dbl> 3, 7, 7, 9, 7, 8, 4, 5, 2, 2, 7, 1, 9, 2, 6…
$ depression                 <dbl> 0, 2, 7, 7, 2, 8, 8, 3, 0, 2, 7, 0, 3, 1, 4…
$ insomnia                   <dbl> 1, 2, 10, 3, 5, 7, 6, 5, 0, 5, 4, 0, 2, 2, …
$ ocd                        <dbl> 0, 1, 2, 3, 9, 7, 0, 3, 0, 1, 7, 1, 7, 0, 0…
$ music_effects              <chr> NA, NA, "No effect", "Improve", "Improve", …
$ permissions                <chr> "I understand.", "I understand.", "I unders…
Code
# | output-fold: true
# Take a peek
df_mh_music_py.glimpse()
Rows: 736
Columns: 33
$ timestamp                  <str> '8/27/2022 19:29:02', '8/27/2022 19:57:31', '8/27/2022 21:28:18', '8/27/2022 21:40:40', '8/27/2022 21:54:47', '8/27/2022 21:56:50', '8/27/2022 22:00:29', '8/27/2022 22:18:59', '8/27/2022 22:33:05', '8/27/2022 22:44:03'
$ age                        <i64> 18, 63, 18, 61, 18, 18, 18, 21, 19, 18
$ primary_streaming_service  <str> 'Spotify', 'Pandora', 'Spotify', 'YouTube Music', 'Spotify', 'Spotify', 'YouTube Music', 'Spotify', 'Spotify', 'I do not use a streaming service.'
$ hours_per_day              <f64> 3.0, 1.5, 4.0, 2.5, 4.0, 5.0, 3.0, 1.0, 6.0, 1.0
$ while_working              <str> 'Yes', 'Yes', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes'
$ instrumentalist            <str> 'Yes', 'No', 'No', 'No', 'No', 'Yes', 'Yes', 'No', 'No', 'No'
$ composer                   <str> 'Yes', 'No', 'No', 'Yes', 'No', 'Yes', 'No', 'No', 'No', 'No'
$ fav_genre                  <str> 'Latin', 'Rock', 'Video game music', 'Jazz', 'R&B', 'Jazz', 'Video game music', 'K pop', 'Rock', 'R&B'
$ exploratory                <str> 'Yes', 'Yes', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'Yes'
$ foreign_languages          <str> 'Yes', 'No', 'Yes', 'Yes', 'No', 'Yes', 'Yes', 'Yes', 'No', 'Yes'
$ bpm                        <i64> 156, 119, 132, 84, 107, 86, 66, 95, 94, 155
$ frequency_classical        <str> 'Rarely', 'Sometimes', 'Never', 'Sometimes', 'Never', 'Rarely', 'Sometimes', 'Never', 'Never', 'Rarely'
$ frequency_country          <str> 'Never', 'Never', 'Never', 'Never', 'Never', 'Sometimes', 'Never', 'Never', 'Very frequently', 'Rarely'
$ frequency_edm              <str> 'Rarely', 'Never', 'Very frequently', 'Never', 'Rarely', 'Never', 'Rarely', 'Rarely', 'Never', 'Rarely'
$ frequency_folk             <str> 'Never', 'Rarely', 'Never', 'Rarely', 'Never', 'Never', 'Sometimes', 'Never', 'Sometimes', 'Rarely'
$ frequency_gospel           <str> 'Never', 'Sometimes', 'Never', 'Sometimes', 'Rarely', 'Never', 'Rarely', 'Never', 'Never', 'Sometimes'
$ frequency_hip_hop          <str> 'Sometimes', 'Rarely', 'Rarely', 'Never', 'Very frequently', 'Sometimes', 'Rarely', 'Very frequently', 'Never', 'Rarely'
$ frequency_jazz             <str> 'Never', 'Very frequently', 'Rarely', 'Very frequently', 'Never', 'Very frequently', 'Sometimes', 'Rarely', 'Never', 'Rarely'
$ frequency_k_pop            <str> 'Very frequently', 'Rarely', 'Very frequently', 'Sometimes', 'Very frequently', 'Very frequently', 'Never', 'Very frequently', 'Never', 'Never'
$ frequency_latin            <str> 'Very frequently', 'Sometimes', 'Never', 'Very frequently', 'Sometimes', 'Rarely', 'Rarely', 'Never', 'Never', 'Rarely'
$ frequency_lofi             <str> 'Rarely', 'Rarely', 'Sometimes', 'Sometimes', 'Sometimes', 'Very frequently', 'Rarely', 'Sometimes', 'Never', 'Rarely'
$ frequency_metal            <str> 'Never', 'Never', 'Sometimes', 'Never', 'Never', 'Rarely', 'Rarely', 'Never', 'Very frequently', 'Never'
$ frequency_pop              <str> 'Very frequently', 'Sometimes', 'Rarely', 'Sometimes', 'Sometimes', 'Very frequently', 'Rarely', 'Sometimes', 'Never', 'Sometimes'
$ frequency_rb               <str> 'Sometimes', 'Sometimes', 'Never', 'Sometimes', 'Very frequently', 'Very frequently', 'Rarely', 'Sometimes', 'Never', 'Sometimes'
$ frequency_rap              <str> 'Very frequently', 'Rarely', 'Rarely', 'Never', 'Very frequently', 'Very frequently', 'Never', 'Rarely', 'Never', 'Rarely'
$ frequency_rock             <str> 'Never', 'Very frequently', 'Rarely', 'Never', 'Never', 'Very frequently', 'Never', 'Never', 'Very frequently', 'Sometimes'
$ frequency_video_game_music <str> 'Sometimes', 'Rarely', 'Very frequently', 'Never', 'Rarely', 'Never', 'Sometimes', 'Rarely', 'Never', 'Sometimes'
$ anxiety                    <f64> 3.0, 7.0, 7.0, 9.0, 7.0, 8.0, 4.0, 5.0, 2.0, 2.0
$ depression                 <f64> 0.0, 2.0, 7.0, 7.0, 2.0, 8.0, 8.0, 3.0, 0.0, 2.0
$ insomnia                   <f64> 1.0, 2.0, 10.0, 3.0, 5.0, 7.0, 6.0, 5.0, 0.0, 5.0
$ ocd                        <f64> 0.0, 1.0, 2.0, 3.0, 9.0, 7.0, 0.0, 3.0, 0.0, 1.0
$ music_effects              <str> null, null, 'No effect', 'Improve', 'Improve', 'Improve', 'Improve', 'Improve', 'Improve', 'Improve'
$ permissions                <str> 'I understand.', 'I understand.', 'I understand.', 'I understand.', 'I understand.', 'I understand.', 'I understand.', 'I understand.', 'I understand.', 'I understand.'

Data Exploration

Looking at the data, I’m interested in predicting the anxiety score from musical listening patterns, specifically the genre frequencies. We’ll need the anxiety column and the frequency_* columns, so let’s visualize what we have here.

For the anxiety score interpretation on a scale from 0 to 10, respondents were asked “how often do you experience this?”:

  • 0 - I do not experience this.

  • 10 - I experience this regularly, constantly/or to an extreme.

So, to no surprise to me, this data is left-skewed: responses are concentrated around scores of 7 to 8, with a thinner tail toward the low end. The density curve also flattens at lower anxiety levels, meaning there are relatively few low scores to learn from, which may affect model performance in predicting across the full range of the outcome.
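
As a quick sanity check on the direction of skew, the sign of the skewness statistic can be computed by hand. This is a minimal sketch on made-up scores (not the survey data); `toy_scores` and `skewness` are hypothetical names. A negative value means the long tail points toward the low end of the scale.

```python
import statistics


def skewness(values):
    """Skewness (population form): mean cubed deviation over sd cubed."""
    mean = statistics.fmean(values)
    sd = statistics.pstdev(values)
    return sum((v - mean) ** 3 for v in values) / (len(values) * sd**3)


# Made-up scores piled up around 7-8 with a thin tail toward 0
toy_scores = [0, 2, 3, 5, 6, 7, 7, 7, 8, 8, 8, 8, 9, 10]
skew = skewness(toy_scores)
print(f"skewness = {skew:.2f}")  # negative: tail toward the low scores
```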

Code
df_mh_music_r |>
  ggplot(aes(x = anxiety)) +
  geom_histogram(
    aes(y = after_stat(density), fill = after_stat(count)),
    bins = 11,
    color = "#000000"
  ) +
  geom_density(color = "#000000", linewidth = 3) +
  geom_density(color = "#60bcde", linewidth = 1) +
  scale_fill_gradient2(
    low = "#ffffff",
    high = "#084164"
  ) +
  labs(
    title = "Distribution of Self-Reported Anxiety Scores",
    subtitle = "Among Survey Respondents",
    fill = "Total Responses:",
    x = "Reported Anxiety Score",
    y = "Density"
  ) +
  theme_minimal() +
  theme(
    plot.title = element_text(size = 14, face = "bold"),
    text = element_text(size = 10),
    legend.position = "top"
  )

I should also look at the music genre categories, as there are 16 different genres that can each have one of four responses: “Never”, “Rarely”, “Sometimes”, or “Very frequently”.

Code
# Need to pivot the genres
df_genre_pivot_r <-
  df_mh_music_r |>
  pivot_longer(
    cols = starts_with("frequency")
  ) |>
  mutate(
    name_cleaned = str_remove(name, "^frequency_"),
    name_cleaned = case_when(
      name_cleaned == "k_pop" ~ "K-pop",
      name_cleaned == "r_b" ~ "r & b",
      name_cleaned == "hip_hop" ~ "hip-hop",
      name_cleaned == "edm" ~ "EDM",
      .default = str_replace_all(name_cleaned, "_", " ")
    ),
    name_cleaned = if_else(
      name_cleaned == "EDM",
      name_cleaned,
      str_to_title(name_cleaned)
    )
  )

df_genre_pivot_r |>
  ggplot(aes(x = name_cleaned, fill = value)) +
  geom_bar(
    color = "#000000"
  ) +
  scale_fill_manual(
    values = c(
      "Never" = "#FFFFFF",
      "Rarely" = "#ACBFCB",
      "Sometimes" = "#5A8097",
      "Very frequently" = "#084164"
    )
  ) +
  labs(
    title = "Distribution of Music Genre Listening Frequency",
    subtitle = "Among Survey Respondents",
    fill = "Response:",
    x = "Musical Genre",
    y = "Count"
  ) +
  theme_minimal() +
  theme(
    panel.grid.major.x = element_line(color = "#000000"),
    plot.title = element_text(size = 14, face = "bold"),
    text = element_text(size = 10),
    axis.text.x.bottom = element_text(
      size = 10,
      angle = 45,
      vjust = 1,
      hjust = 1
    ),
    legend.position = "top"
  )

So a lot of responses have a value of “Never” or “Rarely”. Having 16 different genres, each with one of four possible responses, means that the feature space is way too big: that’s 4^16 possible combinations (4,294,967,296)!!!
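
The arithmetic behind that number is easy to check. A tiny sketch, where the variable names are mine and the 736 comes from the row count shown by glimpse:

```python
n_genres = 16
n_responses = 4
n_respondents = 736  # rows in the survey data

# Every respondent picks one of four responses for each of the 16 genres
n_combinations = n_responses**n_genres
print(f"{n_combinations:,} possible response patterns")
print(f"{n_combinations // n_respondents:,} patterns per respondent")
```

With only 736 respondents, nearly every possible pattern is unobserved, which is the motivation for collapsing the responses.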

I’m going to make a binary response variable, either “Yes” or “No”, where a value of “Very frequently” or “Sometimes” becomes “Yes” and the others become “No”. I’m also going to see which of these genres are listened to by at least 35% of the survey respondents. I’ll mark this with a dashed black line on the graph.

Code
df_genre_pivot_r |>
  mutate(
    value_collapsed = case_when(
      value %in% c("Never", "Rarely") ~ "No",
      value %in% c("Sometimes", "Very frequently") ~ "Yes",
    ),
    name_cleaned_ordered = fct_reorder(
      name_cleaned,
      value_collapsed == "Yes",
      .fun = sum,
      .desc = TRUE
    )
  ) |>
  ggplot(aes(x = name_cleaned_ordered, fill = value_collapsed)) +
  geom_bar(
    color = "#000000"
  ) +
  geom_hline(
    yintercept = 736 * .35,
    color = "#000000",
    linewidth = 1.5,
    linetype = 2
  ) +
  scale_fill_manual(
    values = c(
      "No" = "#FFFFFF",
      "Yes" = "#084164"
    )
  ) +
  labs(
    title = "Distribution of Music Genre Listening Frequency",
    subtitle = "Among Survey Respondents",
    fill = "Response:",
    x = "Musical Genre",
    y = "Count"
  ) +
  theme_minimal() +
  theme(
    panel.grid.major.x = element_line(color = "#000000"),
    plot.title = element_text(size = 14, face = "bold"),
    text = element_text(size = 10),
    axis.text.x.bottom = element_text(
      size = 10,
      angle = 45,
      vjust = 1,
      hjust = 1
    ),
    legend.position = "top"
  )
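
The same collapse-and-threshold logic can be sketched in plain Python. The responses below are made up for illustration (not the survey data), and `collapse`, `toy_responses`, and `kept` are hypothetical names.

```python
from collections import Counter

# Hypothetical toy responses per genre (not the survey data)
toy_responses = {
    "rock": ["Very frequently", "Sometimes", "Never", "Rarely", "Sometimes"],
    "gospel": ["Never", "Never", "Rarely", "Never", "Sometimes"],
}


def collapse(response):
    """Map the four survey responses onto a binary Yes/No."""
    return "Yes" if response in {"Sometimes", "Very frequently"} else "No"


threshold = 0.35
kept = []
for genre, responses in toy_responses.items():
    counts = Counter(collapse(r) for r in responses)
    share_yes = counts["Yes"] / len(responses)
    if share_yes >= threshold:
        kept.append(genre)
    print(f"{genre}: {share_yes:.0%} listen at least sometimes")

print(kept)
```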

So the genres that are listened to by at least 35% of the respondents are Rock, Pop, Hip-Hop, Rap, Classical, Video Game Music, R & B, and Metal. Not bad at all. I’ll drop the responses for the other genres and only focus on these eight most listened to genres.

I’m genuinely curious to see how age plays out with this. Which generation has the most anxiety on average? This data was first uploaded to Kaggle three years ago, so I’ll use 2022 as the reference year to create generation categories like “Millennials”, “Gen-Z”, etc. We’ll calculate them based on this random table Google generated:


Generation Name          Birth Years   Age Range (as of 2022)
The Silent Generation    1928–1945     77–94 years old
Baby Boomers             1946–1964     58–76 years old
Generation X (Gen X)     1965–1980     42–57 years old
Millennials (Gen Y)      1981–1996     26–41 years old
Generation Z (Gen Z)     1997–2012     10–25 years old
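
Before wiring this into a data frame, the lookup itself can be sketched as a plain function. This is a minimal illustration; `generation_for_age` and `GENERATIONS` are hypothetical names, and the birth-year ranges are copied from the table above.

```python
SURVEY_YEAR = 2022

# (label, first birth year, last birth year), taken from the table above
GENERATIONS = [
    ("The Silent Generation", 1928, 1945),
    ("Baby Boomers", 1946, 1964),
    ("Gen-X", 1965, 1980),
    ("Millennials", 1981, 1996),
    ("Gen-Z", 1997, 2012),
]


def generation_for_age(age):
    """Return the generation label for an age reported in the survey year."""
    if age is None:
        return None  # birth year cannot be derived without an age
    birth_year = SURVEY_YEAR - age
    for label, first, last in GENERATIONS:
        if first <= birth_year <= last:
            return label
    return None  # outside every listed range


print(generation_for_age(18))  # Gen-Z
print(generation_for_age(63))  # Baby Boomers
```

Returning None for a missing age or an out-of-range birth year keeps those rows from silently landing in a catch-all group.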


Code
# Set the labels and wrap them for the plot
generation_labels = [
    "The Silent Generation",
    "Baby Boomers",
    "Gen-X",
    "Millennials",
    "Gen-Z",
]

# List comps are my favorite thing about Python so far
generation_labels_wrapped = [
    textwrap.fill(label, width=10) if len(label) > 12 else label
    for label in generation_labels
]

# Can't calc birth year without age, so drop nulls
# then derive birth year
df_mh_music_py_eda = (
    df_mh_music_py.drop_nulls("age")
    .with_columns((2022 - pl.col("age")).alias("birth_year"))
    .with_columns(
        pl.when(pl.col("birth_year").is_between(1928, 1945, closed="both"))
        .then(pl.lit(generation_labels_wrapped[0]))
        .when(pl.col("birth_year").is_between(1946, 1964, closed="both"))
        .then(pl.lit(generation_labels_wrapped[1]))
        .when(pl.col("birth_year").is_between(1965, 1980, closed="both"))
        .then(pl.lit(generation_labels_wrapped[2]))
        .when(pl.col("birth_year").is_between(1981, 1996, closed="both"))
        .then(pl.lit(generation_labels_wrapped[3]))
        .otherwise(pl.lit(generation_labels_wrapped[4]))
        .alias("generation")
    )
    .group_by("generation")
    .agg(pl.col("anxiety").mean().alias("mean_anxiety"), pl.len().alias("n_generation"))
)

df_mh_music_py_eda = df_mh_music_py_eda.with_columns(
    pl.col("mean_anxiety")
    .map_elements(lambda x: f"{x:.2f}" if x is not None else None)
    .alias("top_label"),
    ("(n = " + pl.col("n_generation").cast(str) + ")").alias("bottom_label"),
)

(
    ggplot(df_mh_music_py_eda, aes(x="generation", y="mean_anxiety", fill="mean_anxiety"))
    + geom_bar(stat="identity", color="#000000")
    + geom_text(
        label=df_mh_music_py_eda["top_label"], nudge_y=-0.5, size=13, fontweight="bold"
    )
    + geom_text(label=df_mh_music_py_eda["bottom_label"], nudge_y=-1.1, fontstyle="italic")
    + scale_x_discrete(limits=generation_labels_wrapped)
    + scale_fill_gradient2(low="#f0cf59", high="#822801")
    + theme_minimal()
    + theme(
        plot_title=element_text(size=14, face="bold"),
        text=element_text(size=10),
        legend_position="top",
        axis_text_y=element_blank(),
        axis_title_y=element_blank(),
        axis_title_x=element_blank(),
        axis_text_x=element_text(size=10, face="demibold", color="#000000"),
    )
    + labs(
        title="Mean Self-Reported Anxiety Score",
        subtitle="Among Generational Groups",
        fill="Average Anxiety Score:",
    )
)


To absolutely no one’s surprise, Millennials and Gen-Z have the highest reported anxiety scores. But also, as I expected, major selection bias is going on here, as the sample sizes among the generations are super imbalanced. Do we really have 2 respondents from the Silent Generation? Sounds sus. I expected 0. It’s also interesting that Millennials and Gen-Z are almost identical. We’re just doomed I guess 🫠

Data Cleaning & Preprocessing

Now, we’ll just grab the columns we need and make sure there are no missing values or anything like that. In R, we can also make our frequency_* variables into factors. As mentioned earlier, each genre has four potential responses to the question “How often do you listen to * genre?”

Code
genres_to_keep <-
  paste0(
    "frequency_",
    c(
      "rock",
      "pop",
      "hip_hop",
      "rap",
      "classical",
      "video_game_music",
      "r_b",
      "metal"
    )
  )

df_mh_music_r_sub <-
  df_mh_music_r |>
  # Only select what we need
  select(anxiety, all_of(genres_to_keep)) |>
  drop_na() |>
  mutate(
    across(
      starts_with("frequency"),
      \(x) {
        factor(
          x,
          levels = c("Never", "Rarely", "Sometimes", "Very frequently")
        )
      }
    )
  )

In Python, there isn’t a direct “factor” equivalent like in R. We can make an equivalent while processing the data after we split it in the next step. For now, I’ll just select the relevant columns and make sure the frequency columns are cast as categorical, just to be explicit.

Code
import polars as pl

# Select the relevant columns
freq_cols = [
    "frequency_" + genre
    for genre in [
        "rock",
        "pop",
        "hip_hop",
        "rap",
        "classical",
        "video_game_music",
        "rb",
        "metal",
    ]
]

df_mh_music_py_sub = df_mh_music_py.select(["anxiety"] + freq_cols).drop_nulls()

# Casting frequency columns to categorical
df_mh_music_py_sub = df_mh_music_py_sub.with_columns(
    [pl.col(col).cast(pl.Categorical) for col in freq_cols]
)
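
For intuition, an R factor is roughly a set of ordered levels plus an integer code per value. Here is a plain-Python sketch of that idea; `LEVELS`, `LEVEL_CODES`, and `encode` are hypothetical names, and the codes are zero-based here whereas R’s are one-based.

```python
LEVELS = ["Never", "Rarely", "Sometimes", "Very frequently"]

# Position in LEVELS acts as the integer code behind each label
LEVEL_CODES = {level: code for code, level in enumerate(LEVELS)}


def encode(responses):
    """Convert response strings to ordered integer codes (None stays None)."""
    return [LEVEL_CODES.get(r) for r in responses]


codes = encode(["Rarely", "Very frequently", "Never", None])
print(codes)  # [1, 3, 0, None]
```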

Train/Test Split

From this point on, I’m going to try and mirror tidymodels as closely as I can based on scikit-learn’s current APIs.

Since this will be a basic analysis, I’ll do a simple train/test split while keeping {rsample}’s default proportion of 75%, meaning we’ll retain 75% of the data for training and test on the remaining 25%. I’ll set a seed for reproducibility.

Code
set.seed(901622)

# Make a split w/ rsample
mh_music_r_split <-
  df_mh_music_r_sub |>
  initial_split()

df_music_r_testing <-
  mh_music_r_split |>
  testing()

df_music_r_training <-
  mh_music_r_split |>
  training()

We’ll also do a 75/25 train/test split in scikit-learn. The sklearn.model_selection submodule has train_test_split, which seems to be the closest equivalent to {rsample}’s initial_split.

Code
from sklearn.model_selection import train_test_split

# Features (frequency columns)
feats = df_mh_music_py_sub.drop("anxiety")

# Target (Anxiety score)
target = df_mh_music_py_sub.select("anxiety")

# 75% train / 25% test split: four data frames in total,
# train and test for both features and target
df_feats_train, df_feats_test, df_target_train, df_target_test = train_test_split(
    feats, target, test_size=0.25, random_state=901622
)
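
Conceptually, both splitters shuffle the rows with a seeded random number generator and cut off a test share. The sketch below is a rough stand-in, not either library’s actual algorithm; `simple_split` and `toy_rows` are hypothetical names. One thing worth keeping in mind: R and Python use different random number generators, so the same seed number will not put the same rows in each split.

```python
import math
import random


def simple_split(rows, test_size=0.25, seed=901622):
    """Shuffle row indices with a fixed seed, then cut off the test share."""
    indices = list(range(len(rows)))
    random.Random(seed).shuffle(indices)
    n_test = math.ceil(len(rows) * test_size)
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    return [rows[i] for i in train_idx], [rows[i] for i in test_idx]


toy_rows = list(range(20))  # stand-in for 20 survey rows
train, test = simple_split(toy_rows)
print(len(train), len(test))  # 15 5
```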

Model Building

Recipe Definition

Next, I’ll make a recipe for the training set that will convert the frequency columns into dummies. I’m also noting that given the model I’m choosing in the next step, this probably isn’t required in {parsnip}, but I wanted to make the comparison of this side-by-side with python.

Code
mh_music_recipe <-
  recipe(anxiety ~ ., df_music_r_training) |>
  step_dummy(all_nominal_predictors())

Scikit-learn doesn’t have the “recipe” framework we can use in tidymodels, but it does have ColumnTransformer in its sklearn.compose submodule, which will allow us to create the same preprocessing steps.

Code
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer

frequency_categories = ["Never", "Rarely", "Sometimes", "Very frequently"]

frequency_transformer = ColumnTransformer(
    transformers=[
        (
            "freq_onehot",
            OneHotEncoder(
                categories=[frequency_categories] * len(freq_cols),
                handle_unknown="ignore",
                sparse_output=False,  # dense output; replaces the older sparse=False argument
            ),
            freq_cols,
        )
    ],
    remainder="passthrough",
    verbose_feature_names_out=False,
)

# Have transform() return polars DataFrames
frequency_transformer.set_output(transform="polars")
ColumnTransformer(remainder='passthrough',
                  transformers=[('freq_onehot',
                                 OneHotEncoder(categories=[['Never', 'Rarely',
                                                            'Sometimes',
                                                            'Very frequently'],
                                                           ['Never', 'Rarely',
                                                            'Sometimes',
                                                            'Very frequently'],
                                                           ['Never', 'Rarely',
                                                            'Sometimes',
                                                            'Very frequently'],
                                                           ['Never', 'Rarely',
                                                            'Sometimes',
                                                            'Very frequently'],
                                                           ['Never', 'Rarely',
                                                            'Sometimes',
                                                            'Very frequently'],
                                                           ['Never', 'Rar...
                                                            'Sometimes',
                                                            'Very frequently'],
                                                           ['Never', 'Rarely',
                                                            'Sometimes',
                                                            'Very frequently'],
                                                           ['Never', 'Rarely',
                                                            'Sometimes',
                                                            'Very frequently']],
                                               handle_unknown='ignore',
                                               sparse_output=False),
                                 ['frequency_rock', 'frequency_pop',
                                  'frequency_hip_hop', 'frequency_rap',
                                  'frequency_classical',
                                  'frequency_video_game_music', 'frequency_rb',
                                  'frequency_metal'])],
                  verbose_feature_names_out=False)
Code

# Fit on training set
df_feats_train_transformed = frequency_transformer.fit_transform(df_feats_train)

# Transform test set
df_feats_test_transformed = frequency_transformer.transform(df_feats_test)
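
One difference worth noting between the two encodings: by default, step_dummy() creates one fewer column than there are levels (the first level becomes the reference), while OneHotEncoder as configured above creates a column for every level. OneHotEncoder has a drop="first" option if you want to mirror the R behavior more closely. The plain-Python sketch below illustrates both shapes; `one_hot` is a hypothetical helper, not part of either library.

```python
LEVELS = ["Never", "Rarely", "Sometimes", "Very frequently"]


def one_hot(value, levels=LEVELS, drop_first=False):
    """Encode one value as 0/1 indicators; unknown values become all zeros."""
    row = [1 if value == level else 0 for level in levels]
    return row[1:] if drop_first else row


print(one_hot("Sometimes"))                   # [0, 0, 1, 0]
print(one_hot("Sometimes", drop_first=True))  # [0, 1, 0]
print(one_hot("Never", drop_first=True))      # [0, 0, 0]
print(one_hot("Daily"))                       # [0, 0, 0, 0]
```

The last line mimics handle_unknown="ignore": a value outside the known levels gets all zeros instead of raising an error. For a decision tree, the extra column should matter far less than it would for a linear model.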

Model Specification

I’m not trying to get too into the weeds here, but I’ll do simple tuning for cost_complexity, tree_depth, and min_n. I’ll use the “rpart” engine, which builds trees by recursive partitioning. Lastly, because our target is a numeric anxiety score, we’ll set the mode to “regression”.

Code
model_tree_r <-
  decision_tree(
    cost_complexity = tune(),
    tree_depth = tune(),
    min_n = tune()
  ) |>
  set_engine("rpart") |>
  set_mode("regression")

In Python, the process is similar, using the tree module and its DecisionTreeRegressor class. The training features and target are passed to the model’s fit() method, and the testing features are passed to its predict() method.

Code
from sklearn import tree
model_tree_py = tree.DecisionTreeRegressor(random_state=901622, max_depth=5)
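
To make the fit/predict pattern concrete, here is a toy depth-1 regression “tree” written from scratch. It is only an illustration of the estimator interface and of how a leaf predicts the mean of its training rows; `ToyStump` is a hypothetical class and the data is made up.

```python
class ToyStump:
    """Depth-1 regression tree on a single 0/1 feature (illustration only)."""

    def fit(self, x, y):
        # Each leaf predicts the mean target of the training rows it holds
        left = [yi for xi, yi in zip(x, y) if xi == 0]
        right = [yi for xi, yi in zip(x, y) if xi == 1]
        overall = sum(y) / len(y)
        self.left_mean_ = sum(left) / len(left) if left else overall
        self.right_mean_ = sum(right) / len(right) if right else overall
        return self  # returning self allows ToyStump().fit(...) chaining

    def predict(self, x):
        return [self.right_mean_ if xi == 1 else self.left_mean_ for xi in x]


# Made-up data: 1 = listens to a genre, 0 = does not
x_train = [0, 0, 1, 1, 1]
y_train = [4.0, 6.0, 7.0, 8.0, 9.0]

stump = ToyStump().fit(x_train, y_train)
print(stump.predict([0, 1]))  # [5.0, 8.0]
```

This also shows why a shallow tree can only produce a handful of distinct predictions: there is one value per leaf.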

Workflow Creation

Code
workflow_r <-
  workflow() |>
  add_recipe(mh_music_recipe) |>
  add_model(model_tree_r)
Code
# grid_search isn't defined until the Model Tuning section, so fitting and
# predicting happen there. The features were already transformed above, so
# at this stage the Python "workflow" is just the model specification.
workflow_py = model_tree_py

Model Tuning

Cross-Validation Setup

Code
set.seed(901622)
df_music_folds_r <- vfold_cv(df_music_r_training, v = 5)
Code
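from sklearn.model_selection import cross_val_score

# Note: cross_val_score builds the 5 folds and scores the model on each one,
# so this holds one score per fold (R^2 by default for a regressor)
music_folds_py = cross_val_score(
    model_tree_py,
    df_feats_train_transformed,
    df_target_train,
    cv=5,
)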
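
For reference, this is roughly what 5-fold splitting does under the hood: every row lands in exactly one validation fold, and the model trains on the other four. A plain-Python sketch, where `k_fold_indices` is a hypothetical helper; it uses contiguous folds, whereas vfold_cv() assigns rows to folds at random.

```python
def k_fold_indices(n_rows, k=5):
    """Yield (train_idx, validation_idx) pairs for k contiguous folds."""
    # Spread any remainder over the first folds so sizes differ by at most one
    fold_sizes = [n_rows // k + (1 if i < n_rows % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        validation = list(range(start, start + size))
        train = [i for i in range(n_rows) if i < start or i >= start + size]
        yield train, validation
        start += size


folds = list(k_fold_indices(n_rows=12, k=5))
for train_idx, validation_idx in folds:
    print(len(train_idx), validation_idx)
```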

Tuning Grid & Selection

Code
df_tree_grid_r <- grid_regular(
  cost_complexity(range = c(-10, -3), trans = scales::log10_trans()),
  tree_depth(range = c(2, 10)),
  min_n(range = c(2, 20)),
  levels = 5
)

df_mh_music_tune_results_r <-
  workflow_r |>
  tune_grid(
    resamples = df_music_folds_r,
    grid = df_tree_grid_r,
    metrics = metric_set(rmse, rsq)
  )

show_best(df_mh_music_tune_results_r, metric = "rmse")
# A tibble: 5 × 9
  cost_complexity tree_depth min_n .metric .estimator  mean     n std_err
            <dbl>      <int> <int> <chr>   <chr>      <dbl> <int>   <dbl>
1    0.0000000001          2     2 rmse    standard    2.89     5  0.0868
2    0.0000000001          2     6 rmse    standard    2.89     5  0.0868
3    0.0000000001          2    11 rmse    standard    2.89     5  0.0868
4    0.0000000001          2    15 rmse    standard    2.89     5  0.0868
5    0.0000000001          2    20 rmse    standard    2.89     5  0.0868
# ℹ 1 more variable: .config <chr>
Code
best_params_r <- select_best(df_mh_music_tune_results_r, metric = "rmse")

When tuning these results, I actually got a warning:

→ A | warning: A correlation computation is required,
but `estimate` is constant and has 0 standard deviation,
resulting in a divide by 0 error. `NA` will be returned.
There were issues with some computations   A: x5

Looking at my best models above, I can verify that these warnings were not referring to my best models here. So I’ll continue on with the best one I have:
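
My best guess at what the warning means: yardstick’s rsq is computed as a squared correlation, and a correlation is undefined when every prediction is the same value (for example, when a heavily pruned tree collapses to a single leaf). A small sketch of that failure mode, using made-up numbers and a hypothetical `squared_correlation` helper:

```python
import statistics


def squared_correlation(truth, estimate):
    """Squared Pearson correlation, or None when either input is constant."""
    sd_truth = statistics.pstdev(truth)
    sd_estimate = statistics.pstdev(estimate)
    if sd_truth == 0 or sd_estimate == 0:
        return None  # dividing by a zero standard deviation is undefined
    mean_t = statistics.fmean(truth)
    mean_e = statistics.fmean(estimate)
    covariance = statistics.fmean(
        [(t - mean_t) * (e - mean_e) for t, e in zip(truth, estimate)]
    )
    return (covariance / (sd_truth * sd_estimate)) ** 2


truth = [3.0, 7.0, 7.0, 9.0]
varied = squared_correlation(truth, [5.0, 6.0, 6.5, 8.0])
constant = squared_correlation(truth, [6.5, 6.5, 6.5, 6.5])
print(varied)    # a value between 0 and 1
print(constant)  # None
```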

Code
best_params_r
# A tibble: 1 × 4
  cost_complexity tree_depth min_n .config          
            <dbl>      <int> <int> <chr>            
1    0.0000000001          2     2 pre0_mod001_post0
Code
from sklearn.model_selection import GridSearchCV

param_grid = {
    "ccp_alpha": [1.000000e-10, 5.623413e-09, 3.162278e-07, 1.778279e-05, 1.000000e-03],
    "max_depth": [2, 4, 6, 8, 10],  # matches tree_depth in R
    # Roughly mirrors min_n in R (the R grid used 2, 6, 11, 15, 20)
    "min_samples_leaf": [2, 6, 10, 14, 18],
}

grid_search = GridSearchCV(estimator=model_tree_py, param_grid=param_grid, cv=5)

grid_search = grid_search.fit(df_feats_train_transformed, df_target_train)
final_model_py = grid_search.best_estimator_
y_pred = final_model_py.predict(df_feats_test_transformed)
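
The ccp_alpha values above are not arbitrary: they are five evenly spaced points on the log10 scale from -10 to -3, which is how I understand grid_regular() to lay out cost_complexity. A plain-Python sketch that reproduces them and counts the grid; `regular_levels` is a hypothetical helper.

```python
from itertools import product


def regular_levels(low, high, levels=5):
    """Evenly spaced values from low to high, inclusive."""
    step = (high - low) / (levels - 1)
    return [low + i * step for i in range(levels)]


# cost_complexity is spaced evenly on the log10 scale, then back-transformed
ccp_alpha = [10**exponent for exponent in regular_levels(-10, -3)]
max_depth = [round(v) for v in regular_levels(2, 10)]

grid = list(product(ccp_alpha, max_depth, [2, 6, 10, 14, 18]))
print([f"{a:.6e}" for a in ccp_alpha])
print(max_depth)  # [2, 4, 6, 8, 10]
print(len(grid))  # 125
```

With 125 candidates and 5 folds, that is 625 model fits per language. Note that GridSearchCV picks its winner by R² for a regressor unless told otherwise, while the R side selected on RMSE; passing scoring="neg_root_mean_squared_error" would bring the two closer.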

Final Model Tuning

Code
final_workflow_r <-
  workflow_r |>
  finalize_workflow(best_params_r)

final_fit_r <-
  final_workflow_r |>
  fit(df_music_r_training)
Code
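# best_estimator_ was already refit on the full training set by GridSearchCV
# (refit=True is the default), so this second fit is redundant but harmless
final_model_py = grid_search.best_estimator_
final_model_py.fit(df_feats_train_transformed, df_target_train)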
DecisionTreeRegressor(ccp_alpha=1e-10, max_depth=2, min_samples_leaf=14,
                      random_state=901622)

Model Evaluation

Code
df_music_predictions <-
  predict(final_fit_r, df_music_r_testing) |>
  bind_cols(df_music_r_testing)

metrics(df_music_predictions, truth = anxiety, estimate = .pred)
# A tibble: 3 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 rmse    standard    2.68    
2 rsq     standard    0.000556
3 mae     standard    2.25    
Code
ggplot(df_music_predictions, aes(x = anxiety, y = .pred)) +
  geom_point(alpha = 0.5) +
  geom_abline(color = "red") +
  labs(
    x = "Actual Anxiety",
    y = "Predicted Anxiety",
    title = "Tuned Decision Tree Predictions"
  )

Code
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import numpy as np

# Predict on the test set
y_pred = final_model_py.predict(df_feats_test_transformed)

# Combine predictions and actuals into a DataFrame
df_music_predictions = pl.DataFrame(
    {"anxiety": df_target_test["anxiety"], "pred": y_pred}
)

# Calculate metrics
mse = mean_squared_error(df_music_predictions["anxiety"], df_music_predictions["pred"])
rmse = np.sqrt(mse)
mae = mean_absolute_error(df_music_predictions["anxiety"], df_music_predictions["pred"])
r2 = r2_score(df_music_predictions["anxiety"], df_music_predictions["pred"])

# Print metrics
rmse, mae, r2
(np.float64(2.766844053152648), 2.314042681316123, 0.0055172160779422)
Code
# Plot actual vs predicted
(
    ggplot(df_music_predictions, aes(x="anxiety", y="pred"))
    + geom_point(alpha=0.5)
    + geom_abline(color="red")
    + labs(
        x="Actual Anxiety",
        y="Predicted Anxiety",
        title="Tuned Decision Tree Predictions",
    )
)
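
One caveat when reading the two sets of metrics side by side: RMSE and MAE mean the same thing in both languages, but the R² values do not. As far as I can tell, yardstick’s rsq is a squared correlation, while scikit-learn’s r2_score uses 1 - SS_res / SS_tot. The sketch below computes the scikit-learn-style versions by hand on made-up numbers; `regression_metrics` is a hypothetical helper.

```python
import math


def regression_metrics(truth, pred):
    """Return (rmse, mae, r2) using the 1 - SS_res / SS_tot form of R^2."""
    n = len(truth)
    errors = [t - p for t, p in zip(truth, pred)]
    ss_res = sum(e**2 for e in errors)
    rmse = math.sqrt(ss_res / n)
    mae = sum(abs(e) for e in errors) / n
    mean_truth = sum(truth) / n
    ss_tot = sum((t - mean_truth) ** 2 for t in truth)
    r2 = 1 - ss_res / ss_tot
    return rmse, mae, r2


# Made-up scores: predicting the mean for everyone gives an R^2 of exactly 0
truth = [2.0, 4.0, 6.0, 8.0]
rmse, mae, r2 = regression_metrics(truth, [5.0, 5.0, 5.0, 5.0])
print(rmse, mae, r2)
```

Seen this way, an R² close to zero in either language says the tuned tree is doing little better than predicting the average anxiety score for everyone.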

Comparison & Takeaways

Conclusion