Seaborn - How to make beautiful charts and graphs for Data Analysts

Mar 1, 2024

Sometimes, we do not want to waste time connecting data to BI tools like Looker or Tableau, especially when the task will require some sort of automation in the future.


Python + Seaborn is the savior of a Data or Product Analyst's precious time.


Recently, I found out that my folder of charts generated with Seaborn had passed 1,000 images.


That's why I decided to share my go-to code for generating beautiful, visually appealing charts with Seaborn.


For demonstration purposes, I will use a SaaS dataset of user subscriptions (available on Kaggle) to calculate MRR and retention.


Seaborn - Bar Charts


Lovely Bar Charts.


Everyone knows them, and everyone uses them for different purposes.


I will not dig into the Data Visualization bible to discuss when you should use a Bar Chart.


Here is the default Bar Chart in Seaborn:


Admittedly, it is not Tableau or Looker level.


However, Seaborn is built on top of Matplotlib, so it is easy to customize: anything Matplotlib can do, we can do here as well.
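To illustrate the point, here is a minimal sketch with toy data (the column names `month` and `mrr` are hypothetical): `sns.barplot` returns an ordinary Matplotlib `Axes`, so plain Matplotlib calls can restyle the chart after Seaborn has drawn it.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical toy data: monthly recurring revenue for three months
df = pd.DataFrame({"month": ["Jan", "Feb", "Mar"], "mrr": [12_000, 15_500, 18_200]})

# sns.barplot returns a regular Matplotlib Axes object...
ax = sns.barplot(data=df, x="month", y="mrr")

# ...so any Matplotlib call can be applied to it afterwards
ax.set_title("MRR by month")
ax.set_ylabel("MRR, $")
for side in ("top", "right"):
    ax.spines[side].set_visible(False)  # hide the top and right border lines
```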


Here is my upgraded version:


It took 11 minutes to style it and make it reusable.


What I've changed:

  • Formatted the X and Y axis labels

  • Moved the legend outside the chart

  • Added a title

  • Added formatted value labels on top of each bar

  • Removed the border



Still not ideal, but definitely better.


What can be improved: adding a stacking option (currently this can be done by combining two separate charts), auto-formatting the Y-axis labels, and rendering only every N-th date on the X-axis.
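For the last point, one possible approach (a sketch, not part of the function below) is to keep every tick but hide the labels of all except every N-th one. The helper name `show_every_nth_xlabel` and the toy monthly data are hypothetical.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def show_every_nth_xlabel(ax, n: int = 3):
    """Hypothetical helper: hide all X tick labels except every n-th one."""
    for i, tick in enumerate(ax.xaxis.get_major_ticks()):
        # Keep the labels at positions 0, n, 2n, ... and hide the rest
        tick.label1.set_visible(i % n == 0)
    return ax


# Hypothetical toy data: 12 monthly values
months = pd.date_range("2023-01-01", periods=12, freq="MS").strftime("%Y-%m")
df = pd.DataFrame({"month": months, "mrr": range(1_000, 13_000, 1_000)})

ax = sns.barplot(data=df, x="month", y="mrr")
ax.tick_params(axis="x", rotation=90)
show_every_nth_xlabel(ax, n=3)
```

The bars themselves are untouched; only the text under them is thinned out, so the chart keeps its full resolution while the axis stays readable.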


Let's take a look at the code for the upgraded Seaborn barplot:

import seaborn as sns
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import ticker


def generate_bar_chart(x:pd.Series,
                       y:pd.Series,
                       hue:pd.Series=None,
                       color_pallete:str='rocket_r',
                       size:tuple[int,int]=(12,8),
                       title:str=None,
                       save_img:bool=False):

    # Set the color palette

    sns.set_palette(sns.color_palette(color_pallete))

    # Set the chart size

    fig = plt.figure(figsize=size)

    # Bar chart (errorbar=None turns off the confidence-interval whiskers)

    ax = sns.barplot(
            x=x,
            y=y,
            orient='v',
            saturation=1,
            width=0.75,
            hue=hue,
            errorbar=None
        )

    # Y-label formatting: show ticks in thousands, e.g. 12000 -> "12.0K $"
    # (a tick formatter avoids the FixedFormatter warning of set_yticklabels)

    ax.yaxis.set_major_formatter(
        ticker.FuncFormatter(lambda value, _pos: '{:1.1f}K $'.format(value / 1000))
    )

    # Adding formatted value labels on top of each bar
    # (show_values_on_bars is defined below; it must exist before this function is called)

    show_values_on_bars(ax)

    # If hue is set, move the legend outside the chart

    if hue is not None:
        sns.move_legend(
            ax, "upper left",
            bbox_to_anchor=(1.05, 1), borderaxespad=0)

    # Removing the border around the chart

    plt.box(False)

    # Rotating X-labels

    ax.tick_params(axis='x', rotation=90)

    plt.title(title)
    plt.tight_layout()

    # Save to the current folder

    if save_img:
        img_name = f"{title}.jpg"
        plt.savefig(img_name, dpi='figure', bbox_inches='tight')

    # Rendering chart
    plt.show()

    # Returning the chart's Axes object
    return ax


Pretty simple set of instructions. Nothing fancy.


However, the most interesting part is rendering the formatted value labels.


Here is the code to properly render formatted value labels:

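def show_values_on_bars(axs):

    def _format_value(value):
        # Pick units and precision based on the bar's magnitude
        if value >= 500_000:
            return '{:1.1f}M $'.format(value * 0.000_001)
        elif value >= 10_000:
            return '{:1.0f}K $'.format(value * 0.001)
        elif value >= 1_000:
            return '{:1.1f}K $'.format(value * 0.001)
        elif value >= 10:
            return '{:1.0f} $'.format(value)
        elif value >= 1:
            return '{:1.1f} $'.format(value)
        else:
            return '{:1.3f} $'.format(value)

    def _show_on_single_plot(ax):

        for p in ax.patches:
            height = p.get_height()

            # Skip empty bars (e.g. hue levels missing for some x values)
            if np.isnan(height):
                continue

            # Label position: center of the bar, 2% of its height above the top
            _x = p.get_x() + p.get_width() / 2
            _y = p.get_y() + height + height / 50

            # Setting up and formatting text labels; the label shows the bar height itself
            ax.text(_x, _y, _format_value(height),
                    ha="center", rotation=90,
                    size=9
            )

    # Works for a single Axes or a NumPy array of Axes (e.g. from plt.subplots)
    if isinstance(axs, np.ndarray):
        for idx, ax in np.ndenumerate(axs):
            _show_on_single_plot(ax)
    else:
        _show_on_single_plot(axs)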


That's all for Bar Charts.


Seaborn - Heatmap Charts


Whenever I need to check cohort statistics like retention, LTV, or churn, a heatmap is a nice choice for getting a clue about each cohort's progress.


Before we start, let's take a look at the data. I will use a DataFrame with retention data for the subscriptions.


Initial look:

signup_date_time - the monthly signup cohort

months_from_signup - number of months passed since signup_date_time

customers - how many unique customers from the signup_date_time cohort are still active months_from_signup months later

total_customers - total number of unique customers in this cohort

retention - the share of total_customers in this cohort that are still active at months_from_signup
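A minimal sketch of how these columns relate, assuming total_customers is the number of customers active in month 0 of each cohort (the cohort values are made up):

```python
import pandas as pd

# Hypothetical long-format cohort counts: one row per (cohort, month offset)
cohorts = pd.DataFrame({
    "signup_date_time": ["2023-01"] * 3 + ["2023-02"] * 2,
    "months_from_signup": [0, 1, 2, 0, 1],
    "customers": [200, 150, 120, 100, 90],
})

# Assumption: total_customers = customers active in month 0 of each cohort
cohorts = cohorts.sort_values(["signup_date_time", "months_from_signup"])
cohorts["total_customers"] = (
    cohorts.groupby("signup_date_time")["customers"].transform("first")
)

# Share of the cohort still active after months_from_signup months
cohorts["retention"] = cohorts["customers"] / cohorts["total_customers"]
```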


Pivoted and annotation tables:

A heatmap needs the data in this wide format to render properly. This is easy to do with the DataFrame.pivot() method.
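For example, with toy data in the long format described above, `DataFrame.pivot` turns cohorts into rows and months into columns. Months a cohort has not reached yet become NaN, and Seaborn's heatmap leaves those cells uncolored. The values below are hypothetical.

```python
import pandas as pd

# Hypothetical long-format retention data with the columns described above
retention = pd.DataFrame({
    "signup_date_time": ["2023-01", "2023-01", "2023-01", "2023-02", "2023-02"],
    "months_from_signup": [0, 1, 2, 0, 1],
    "retention": [1.0, 0.75, 0.6, 1.0, 0.9],
})

# Rows = cohorts, columns = months since signup, cells = retention share
heatmap_data = retention.pivot(
    index="signup_date_time", columns="months_from_signup", values="retention"
)
```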


Okay, here is the default Heatmap chart:


You may say that it would look better without the value labels, and it definitely would.


However, without proper labels it is too hard to find insights, especially for people who cannot perceive colors normally.


And here is the upgraded one:


What I've changed:

  • Moved the X-axis to the top

  • Changed the color palette

  • Added a title

  • Rendered value labels only for every 5th column

  • Removed the lines between rows and columns


I used a separate DataFrame with annotation values so that labels are shown only for every 5th column.


Code for the upgraded one:

import seaborn as sns
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

def generate_heatmap_chart(
                            data:pd.DataFrame,
                            annotation_data:pd.DataFrame,
                            cmap_pallete:str='Blues',
                            size:tuple[int,int]=(16,8),
                            text_size:int=8,
                            title:str=None,
                            save_img:bool=False):

    # Set the chart size

    fig, ax = plt.subplots(figsize=size)

    # Heatmap: colors come from `data`, text comes from `annotation_data`
    # (fmt='' lets annotation cells hold ready-made text, including empty strings)

    sns.heatmap(data, annot=annotation_data, fmt='', linewidths=0, ax=ax,
                cmap=cmap_pallete, annot_kws={"size": text_size})

    # Removing the border around the chart area and moving the X-axis to the top

    plt.box(False)
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position('top')

    # Setting title and layout

    plt.title(title, y=1.05, fontsize=18)
    plt.tight_layout()

    # Save to the current folder

    if save_img:
        img_name = f"{title}.jpg"
        plt.savefig(img_name, dpi='figure', bbox_inches='tight')

    # Rendering chart

    plt.show()

    # Returning the chart's Axes object

    return ax


Code for creating a separate annotation DataFrame that shows only every N-th value label:

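def annotate_each_n_value(data:pd.DataFrame, each_n:int=5):
    # Work on a copy so the DataFrame used for the heatmap colors is not modified
    data = data.copy()
    # Columns that are NOT at positions 0, each_n, 2*each_n, ...
    cols = [col for col in data.columns if col not in data.iloc[:,::each_n].columns]
    # Blank out their annotations
    data[cols] = ''
    return data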
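Since the heatmap uses `fmt=''`, numeric values that remain in the annotation DataFrame are printed with full precision (e.g. 0.7333333333333333). One option, sketched below under the assumption that the values are retention shares between 0 and 1, is to turn the kept columns into percentage strings up front. `annotate_each_n_value_formatted` is a hypothetical variant of the function above, not the code used for the charts in this article.

```python
import pandas as pd


def annotate_each_n_value_formatted(data: pd.DataFrame, each_n: int = 5,
                                    fmt: str = "{:.0%}") -> pd.DataFrame:
    """Hypothetical variant: keep every each_n-th column as formatted text, blank the rest."""
    # Format every value as text; NaN cells become empty strings
    annot = data.apply(lambda col: col.map(lambda v: "" if pd.isna(v) else fmt.format(v)))
    # Blank out all columns except positions 0, each_n, 2*each_n, ...
    hidden = [i for i in range(data.shape[1]) if i % each_n != 0]
    annot.iloc[:, hidden] = ""
    return annot


# Usage with a hypothetical pivoted retention table
heatmap_data = pd.DataFrame(
    [[1.0, 0.75, 0.6], [1.0, 0.9, None]],
    index=["2023-01", "2023-02"], columns=[0, 1, 2],
)
annot = annotate_each_n_value_formatted(heatmap_data, each_n=2)
```

The result can be passed as annotation_data to generate_heatmap_chart, while the numeric pivot stays untouched and keeps driving the colors.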


That's all, folks.