网站搜索

使用 Seaborn 在 Python 中可视化数据


如果您有一些使用 Python 进行数据分析的经验,那么您很可能已经制作了一些数据图来向其他人解释您的分析。您很可能会使用 Matplotlib 等库来生成这些内容。如果您想将统计可视化提升到一个新的水平,您应该掌握 Python seaborn 库来生成令人印象深刻的统计分析图来显示您的数据。

在本教程中,您将学习如何:

  • 做出明智的判断 Seaborn 是否满足您的数据可视化需求
  • 了解seaborn经典Python函数式接口的原理
  • 了解seaborn更现代的Python对象接口的原理
  • 使用seaborn的函数创建Python绘图
  • 使用seaborn的对象创建Python绘图

在开始之前,您应该熟悉 JupyterLab 中提供的 Jupyter Notebook 数据分析工具。尽管您可以使用您最喜欢的 Python 环境来学习此 seaborn 教程,但首选 Jupyter Notebook。您可能还想了解 pandas DataFrame 如何存储其数据。了解 pandas DataFrame 和 Series 之间的区别也将很有用。

现在,您可以立即开始学习如何使用 seaborn 来生成 Python 绘图。

Python Seaborn 入门

在使用seaborn之前,您必须安装它。打开 Jupyter Notebook 并在新的代码单元中输入 !python -m pip install seaborn 。当您运行单元时,seaborn 将安装。如果您在命令行上工作,请使用相同的命令,只是不带感叹号 (!)。一旦安装了seaborn,Matplotlib、pandas 和NumPy 也将可用。这很方便,因为有时您需要它们来增强您的 Python seaborn 绘图。

当然,在创建绘图之前,您需要数据。稍后,您将使用包含真实世界数据的不同公开可用数据集创建多个绘图。首先,您将使用seaborn 的创建者为您提供的一些示例数据。更具体地说,您将使用他们的 tips 数据集。该数据集包含特定餐厅服务员在几个月内收到的每条小费的数据。

使用seaborn创建条形图

假设您想查看一个条形图,显示服务员每天收到的平均小费金额。您可以编写一些 Python seaborn 代码来执行此操作:

In [1]: import matplotlib.pyplot as plt
   ...: import seaborn as sns
   ...:
   ...: tips = sns.load_dataset("tips")
   ...:
   ...: (
   ...:     sns.barplot(
   ...:         data=tips, x="day", y="tip",
   ...:         estimator="mean", errorbar=None,
   ...:     )
   ...:     .set(title="Daily Tips ($)")
   ...: )
   ...:
   ...: plt.show()

首先,将seaborn 导入到Python 代码中。按照惯例,您将其导入为 sns。尽管您可以使用任何您喜欢的别名,但 sns 是对库命名的虚构人物的致敬。

要在seaborn中处理数据,通常将其加载到pandas DataFrame中,尽管也可以使用其他数据结构。加载数据的常用方法是使用 pandas read_csv() 函数从磁盘上的文件读取数据。稍后您将看到如何执行此操作。

首先,因为您正在使用seaborn示例数据集之一,所以seaborn允许您使用其load_dataset()函数在线访问这些数据集。您可以在其 GitHub 存储库上查看免费可用文件的列表。要获得您想要的数据集,您需要做的就是向 load_dataset() 传递一个字符串,告诉它包含您感兴趣的数据集的文件的名称,它将被加载到pandas DataFrame 供您使用。

实际的条形图是使用seaborn的barplot()函数创建的。稍后您将了解有关不同绘图函数的更多信息,但现在,您已指定 data=tips 作为您希望使用的 DataFrame,并告诉该函数绘制 daytip 列。它们分别包含收到小费的日期和小费金额。

这里您应该注意的重要一点是,seaborn barplot() 函数与所有 seaborn 绘图函数一样,可以本能地理解 pandas DataFrame。要指定供他们使用的数据列,可以将其列名称作为字符串传递。无需编写 pandas 代码来识别要绘制的每个系列。

estimator="mean" 参数告诉seaborn 绘制x 每个类别的y 平均值。 这意味着您的绘图将显示每天的平均小费。您可以快速对其进行自定义,以使用常见的统计函数,例如 summaxminmedian,但 estimator="mean" 是默认值。默认情况下,该图还将显示误差线。通过设置errorbar=None,您可以抑制它们。

barplot() 函数将使用您传递给它的参数生成一个绘图,并且它将使用您想要查看的数据的列名称来标记每个轴。一旦 barplot() 完成,它会返回一个包含绘图的 matplotlib Axes 对象。要为绘图指定标题,您需要调用 Axes 对象的 .set() 方法并向其传递所需的标题。请注意,这一切都是直接在seaborn 中完成的,而不是Matplotlib。

在 IPython 和 PyCharm 等某些环境中,您可能需要使用 Matplotlib 的 show() 函数来显示绘图,这意味着您也必须将 Matplotlib 导入到 Python 中。如果您使用的是 Jupyter 笔记本,则不需要使用 plt.show() ,但使用它会删除绘图上方的一些不需要的文本。在 barplot() 末尾放置一个分号 (;) 也可以为您完成此操作。

当您运行代码时,结果图将如下所示:

正如您所看到的,服务员的每日平均小费在周末略有上升。看起来人们在放松的时候给的小费更多。

In [2]: tips["day"]
Out[2]:
0       Sun
1       Sun
2       Sun
3       Sun
4       Sun
       ...
239     Sat
240     Sat
241     Sat
242     Sat
243    Thur
Name: day, Length: 244, dtype: category
Categories (4, object): ['Thur', 'Fri', 'Sat', 'Sun']

如您所见,您的 day 列的数据类型为 category。另请注意,虽然您的原始数据以 Sun 开头,但 category 中的第一个条目是 Thur。在创建类别时,已按正确的顺序为您解释了日期。 read_csv() 函数不会执行此操作。

接下来,您将使用 Matplotlib 代码创建相同的绘图。这将使您看到两个库之间代码风格的差异。

使用 Matplotlib 创建条形图

现在看一下下面所示的 Matplotlib 代码。当你运行它时,它会产生与seaborn代码相同的输出,但代码远没有那么简洁:

In [3]: import matplotlib.pyplot as plt
   ...: import pandas as pd
   ...:
   ...: tips = pd.read_csv("tips.csv")
   ...:
   ...: average_daily_tip = (
   ...:     tips
   ...:     .groupby("day")["tip"]
   ...:     .mean()
   ...: )
   ...:
   ...: days = ["Thur", "Fri", "Sat", "Sun"]
   ...: daily_averages = [
   ...:     average_daily_tip["Thur"],
   ...:     average_daily_tip["Fri"],
   ...:     average_daily_tip["Sat"],
   ...:     average_daily_tip["Sun"],
   ...: ]
   ...:
   ...: fig, ax = plt.subplots()
   ...: plt.bar(x=days, height=daily_averages)
   ...: ax.set_xlabel("day")
   ...: ax.set_ylabel("tip")
   ...: ax.set_title("Daily Tips ($)")
   ...:
   ...: plt.show()

这次,您混合使用了 pandas 和 Matplotlib,因此您必须导入两者。

首先,您使用 pandas read_csv() 函数读取 tips.csv 文件。然后,您必须使用 DataFrame 的 .groupby() 方法手动对数据进行分组,然后再使用 .mean() 计算每天的平均值。

接下来,您手动指定要绘制的数据以及绘制数据的顺序。当 read_csv() 读入数据时,它不会对其进行分类或应用任何顺序为你。作为补偿,您可以将要绘制的内容指定为 daysdaily_averages 列表。

要生成绘图,您可以使用 Matplotlib 的 bar() 函数并指定要绘制的两个数据系列。在本例中,您传递 x=daysheight=daily_averages。最后,应用轴标签和绘图标题。

如果运行此代码,您将看到与以前相同的绘图。

如果您想将绘图保存到外部文件,也许是为了在演示文稿或报告中使用它们,那么有几个选项可供您选择。

在许多环境中(例如 PyCharm),当您调用 plt.show() 时,绘图将显示在不同的窗口中。通常此窗口包含其自己的文件保存工具。

如果您使用的是 Jupyter 笔记本,则可以右键单击绘图并将其复制到剪贴板,然后再将其粘贴到报告或演示文稿中。

您还可以对代码进行一些调整,以便自动发生:

In [4]: import matplotlib.pyplot as plt
   ...:
   ...: import seaborn as sns
   ...:
   ...: tips = sns.load_dataset("tips")
   ...:
   ...: (
   ...:     sns.barplot(
   ...:         data=tips, x="day", y="tip",
   ...:         estimator="mean", errorbar=None,
   ...:     )
   ...:     .set_title("Daily Tips ($)")
   ...:     .figure.savefig("daily_tips.png")
   ...: )
   ...:
   ...: plt.show()

在这里,您使用了绘图的 .figure 属性,该属性允许您访问底层 Matplotlib 图窗,然后调用其 .savefig() 方法来保存它到 png 文件。默认为 png,但 .savefig() 还允许您传递常见的替代图形格式,包括 "jpeg" “pdf”“ps”

您可能已经注意到,条形图的标题是使用 .set_title("Daily Tips ($)") 方法设置的,而不是 .set(title="Daily Tips ($) ") 您之前使用过的方法。虽然您通常可以互换使用它们,但当您想使用 figure.savefig() 保存图形时,使用 .set_title("Daily Tips ($)") 更具可读性。

原因是 .set_title("Daily Tips ($)") 返回一个 matplotlib.text.Text 对象,其底层关联的 Figure 可以使用 .figure 属性访问对象。 这是使用 .savefig() 方法时保存的内容。

如果您使用 .set(title="Daily Tips ($)"),这仍然会返回一个 Text 对象。但是,它是列表中的第一个元素。要访问它,您需要使用 .set(title="Daily Tips ($)")[0].figure.savefig("daily_tips.png"),它的可读性较差。

希望这篇介绍能让您体验到seaborn。您已经看到seaborn 的Python 代码比Matplotlib 使用的代码相对清晰。这是可能的,因为 Seaborn 向您隐藏了 Matplotlib 的大部分复杂性。正如您在 barplot() 函数中看到的,seaborn 将数据作为 pandas DataFrame 传递,并且绘图函数理解其结构。

绘图函数是seaborn经典函数式界面的一部分,但它们只是故事的一半。

使用seaborn的一种更现代的方法是使用称为对象接口的东西。这提供了一种声明性语法,这意味着您可以使用各种对象定义您想要的内容,然后让seaborn将它们组合到您的图中。这使得创建绘图的方法更加一致,从而使界面更易于学习。它还隐藏了底层的 Matplotlib 功能,甚至比绘图功能还要多。

现在,您将继续学习如何使用每个界面。

了解seaborn的经典函数式接口

Seaborn 经典功能界面包含一组用于创建不同绘图类型的绘图函数。您之前使用 barplot() 函数时已经看到过这样的示例。函数式接口将其绘图函数分为几种主要类型。最常见的三种如下图所示:

第一列显示seaborn的关系图。这些可以帮助您了解数据集中的变量对如何相互关联。常见的例子是散点图和线图。例如,您可能想知道随着产品价格上涨利润如何变化。还有一个回归图类别,可以添加回归线,稍后您将看到。

第二列显示seaborn的分布图。这些可以帮助您了解数据集中的变量是如何分布的。常见的例子包括直方图和地毯图。例如,您可能想查看在国家考试中获得的每个成绩的计数。

第三列显示seaborn的分类图。这些还可以帮助您了解数据集中的变量对如何相互关联。然而,其中一个变量通常包含离散类别。常见的示例包括条形图和箱形图。您之前看到的按天分类的服务员的平均小费是分类图的一个示例。

您可能还注意到绘图函数有一个层次结构。您还可以将每个分类定义为图形级别轴级别函数。这提供了很大的灵活性。

图形级函数允许您绘制多个子图,每个子图显示不同类别的数据。例如,您可能想知道利润如何随多种产品的价格上涨而变化,但希望每种产品都有单独的子图。您在图形级函数中指定的参数适用于每个子图,这使它们具有一致的外观和感觉。 relplot()displot()catplot() 函数都是图形级别的。

相反,轴级函数允许您绘制单个图。这次,您向轴级函数提供的任何参数仅适用于该函数生成的单个图。每个轴级图在图表上用椭圆形表示。 lineplot()histplot()boxplot() 函数都是轴级函数。

接下来,您将仔细了解如何使用轴级函数来生成单个图。

使用轴级函数

当您需要的只是一个图时,您很可能会使用轴级函数。在此示例中,您将使用名为 cycle_crossings_apr_jun.csv 的文件。其中包含纽约不同桥梁的自行车道口数据。原始数据来自纽约市开放数据,但可下载材料中提供了副本。

您需要做的第一件事是将 Cycle_crossings_apr_jun.csv 文件读入 pandas DataFrame。为此,您可以使用 read_csv() 函数:

In [1]: import pandas as pd
   ...:
   ...: crossings = pd.read_csv("cycle_crossings_apr_jun.csv")

Crossings DataFrame 现在包含文件的全部内容。因此,数据可用于可视化。

假设您想查看文件中包含的三个月数据的最高温度和最低温度之间是否存在任何关系。实现此目的的一种方法是使用散点图。为此,Seaborn 提供了一个 scatterplot() 轴级函数:

In [2]: import matplotlib.pyplot as plt
   ...: import seaborn as sns
   ...:
   ...: (
   ...:     sns.scatterplot(
   ...:         data=crossings, x="min_temp", y="max_temp"
   ...:     )
   ...:     .set(
   ...:         title="Minimum vs Maximum Temperature",
   ...:         xlabel="Minimum Temperature",
   ...:         ylabel="Maximum Temperature",
   ...:     )
   ...: )
   ...:
   ...: plt.show()

您在此处使用 scatterplot() 函数的方式与使用 barplot() 的方式类似。再次提供 DataFrame 作为其 data 参数,然后提供要绘制的列。作为增强功能,您还可以调用 Matplotlib 的 Axes.set() 方法为绘图指定标题,并使用 xlabel 和 ylabel 来标记每个轴。默认情况下,没有标题,每个轴都根据其数据系列进行标记。使用 Axes.set() 允许大写。

结果图如下所示:

尽管每个图形级函数都需要自己的一组参数,并且您应该阅读 seaborn 文档以了解可用的参数,但大多数函数中都有一个强大的参数,称为 hue。此参数允许您为绘图上不同类别的数据添加不同的颜色。要使用它,您需要传入要应用着色的列的名称。

关系绘图函数还支持 stylesize 参数,允许您对每个点应用不同的样式和大小。这些可以进一步阐明你的情节。您决定更新您的绘图以包含它们:

In [3]: (
   ...:    sns.scatterplot(
   ...:         data=crossings, x="min_temp", y="max_temp",
   ...:         hue="month", size="month", style="month",
   ...:    )
   ...:     .set(
   ...:         title="Minimum vs Maximum Temperature",
   ...:         xlabel="Minimum Temperature",
   ...:         ylabel="Maximum Temperature",
   ...:     )
   ...: )
   ...:
   ...: plt.legend(title="Month")
   ...: plt.show()

尽管完全可以将 huesizestyle 设置为 DataFrame 中的不同列,但只需将它们全部设置为 "month ",您分别为每个月的数据点指定不同的颜色、大小和符号。您可以在下面更新的图中看到这一点:

尽管应用所有三个参数可能有点矫枉过正,但在这种情况下,您现在可以看到每个点属于哪个月份。您也在单个函数调用中完成了所有这些操作。

另请注意,seaborn 已为您提供了有用的图例。但是,图例的默认标题与传递给“hue”的数据系列相同。为了将其大写,您使用了 legend() 函数。

您将在本教程的后面部分看到更多轴级绘图函数,但现在是时候查看实际的图形级函数了。

使用图形级函数

有时您可能需要数据的多个子图,每个子图显示数据的不同类别。您可以手动创建多个绘图,但图形级函数会自动为您执行此操作。

与轴级函数一样,每个图形级函数都包含一些您应该学习如何使用的常用参数。 rowcol 参数允许您指定将在每个子图中显示的行或列数据系列。设置column参数会将每个子图放置在各自的列中,而设置row参数将为每个子图提供一个单独的行。

例如,假设您想查看每个月温度的单独散点图:

In [1]: import matplotlib.pyplot as plt
   ...: import pandas as pd
   ...: import seaborn as sns
   ...:
   ...: crossings = pd.read_csv("cycle_crossings_apr_jun.csv")
   ...:
   ...: (
   ...:     sns.relplot(
   ...:         data=crossings, x="min_temp", y="max_temp",
   ...:         kind="scatter", hue="month", col="month",
   ...:     )
   ...:     .set(
   ...:         title="Minimum vs Maximum Temperature",
   ...:         xlabel="Minimum Temperature",
   ...:         ylabel="Maximum Temperature",
   ...:     )
   ...:     .legend.set_title("Month")
   ...: )
   ...:
   ...: plt.show()

与轴级函数一样,当使用图形级绘图函数时,您可以传入 DataFrame 并突出显示其中您有兴趣查看的系列。在此示例中,您使用了 relplot(),并通过设置 kind="scatter",告诉函数创建多个散点图子图。

hue 参数仍然存在,并且仍然允许您将不同的颜色应用于子图。事实上,建议您始终将其与图形级绘图功能一起使用,以强制seaborn为您创建图例。这澄清了每个子图。但是,默认图例标题将为小写的“month”

通过设置 col="month",每个子图将位于其自己的列中,每列代表一个单独的月份。这意味着您会看到一排。

图形级绘图函数(例如 relplot())会创建一个 FacetGrid 对象,在该对象上放置每个子图。要大写由图形级绘图创建的图例,请使用 FacetGrid 的 .legend 访问器来访问 .set_title()。然后,您可以为底层 FacetGrid 对象添加图例标题。

你的情节现在看起来像这样:

您创建了三个独立的散点图,一个用于每个月的数据。每个图都被赋予了单独的颜色,并准备了方便的图例,以便您更好地识别每个图向您展示的内容。

稍后您将看到更多函数接口的示例,但现在,是时候认识一下这个领域中相对较新的孩子了:seaborn 的对象接口。

介绍seaborn的当代对象界面

在本节中,您将了解seaborn对象接口的核心组件。这使用了更具声明性的语法,这意味着您可以通过创建和添加创建绘图所需的各个对象来分层构建绘图。以前,函数会为您完成此操作。

当您使用seaborn对象构建绘图时,您使用的第一个对象是Plot。该对象引用您正在绘制其数据的 DataFrame,以及其中您有兴趣查看其数据的特定列。

假设您想使用对象接口构建前面的温度散点图示例。 Plot 对象将是您的起点:

In [1]: import pandas as pd
   ...: import seaborn.objects as so
   ...:
   ...: crossings = pd.read_csv("cycle_crossings_apr_jun.csv")
   ...:
   ...: (
   ...:     so.Plot(
   ...:         data=crossings, x="min_temp", y="max_temp"
   ...:     )
   ...:     .show()
   ...: )

当您使用seaborn对象接口时,惯例是使用别名so将其导入Python。上面的代码重用了您之前创建的 crossings DataFrame。

要创建 Plot 对象,您可以调用其构造函数并传入包含您的数据的 DataFrame 以及包含您希望绘制的数据系列的列的名称。此处,xmin_tempymax_tempPlot 对象现在有可以使用的数据。

Plot 对象包含其自己的 .show() 方法来显示它。与前面讨论的 plt.show() 一样,在 Jupyter Notebook 中不需要它。

当您运行代码时,输出可能并不会让您兴奋:

如您所见,数据图无处可见。这是因为 Plot 对象只是绘图的背景。要查看某些内容,您需要通过向 Plot 对象添加一个或多个 Mark 对象来构建它。 Mark 对象是一系列子类的基类,每个子类代表数据可视化的不同部分。

接下来,向 Plot 对象添加一些内容以使其更有意义:

In [2]: (
   ...:     so.Plot(
   ...:         data=crossings, x="min_temp", y="max_temp"
   ...:      )
   ...:     .add(so.Dot())
   ...:     .label(
   ...:            title="Minimum vs Maximum Temperature",
   ...:            x="Minimum Temperature",
   ...:            y="Maximum Temperature",
   ...:     )
   ...:     .show()
   ...: )

要将 Plot 对象的数据显示为散点图,您需要向其中添加多个 Dot 对象。 Dot 类是 Mark 的子类,它将每个 xy 对显示为点。要添加 Dot 对象,请调用 Plot 对象的 .add() 方法,并传入要添加的对象。每次调用 .add() 时,您都会在 Plot 上添加新的细节层

作为最后一步,您可以标记绘图及其每个轴。为此,您可以调用 Plot.label() 方法。 title 参数为绘图提供标题,而 xy 参数分别标记关联的轴。

当您运行代码时,它看起来与您的第一个散点图相同,甚至包括标题和轴标签:

接下来,您可以通过将每一天分成单独的颜色和符号来改进您的绘图:

In [3]: (
   ...:     so.Plot(
   ...:         data=crossings, x="min_temp",
   ...:         y="max_temp", color="month",
   ...:     )
   ...:     .add(so.Dot(), marker="month")
   ...:     .label(
   ...:         title="Minimum vs Maximum Temperature",
   ...:         x="Minimum Temperature",
   ...:         y="Maximum Temperature",
   ...:         color=str.capitalize,
   ...:     )
   ...:     .show()
   ...: )

要将每个月的数据分隔为具有不同颜色的标记,请将要分隔其数据的列作为其 color 参数传递到 Plot 对象中。在这种情况下,color="month" 将为每个不同的月份分配不同的颜色。这提供了与您之前看到的函数接口使用的 hue 参数类似的功能。

要将不同的标记样式应用于代表每月的点,您需要将 marker 变量传递到定义 Dot 对象的同一图层。在本例中,您设置 marker="month" 来定义您希望区分其标记样式的系列。

您可以按照与之前的绘图相同的方式标记标题和轴。要为图例添加标签,您还可以使用 Plot 对象的 .label() 方法。通过传递 color=str.capitalize,您将把字符串的 .capitalize() 方法应用到 month 的默认标签,从而导致显示为月份xy 参数可以以相同的方式设置,但下划线将保留。您也可以设置 color="Month" 以获得相同的结果。

你的情节现在看起来像这样:

下一阶段是将每个月的数据分成单独的图:

In [4]: (
   ...:     so.Plot(
   ...:         data=crossings, x="min_temp",
   ...:         y="max_temp", color="month",
   ...:     )
   ...:     .add(so.Dot(), marker="month")
   ...:     .facet(col="month")
   ...:     .layout(size=(15, 5))
   ...:     .label(
   ...:         title="Minimum vs Maximum Temperature",
   ...:         x="Minimum Temperature",
   ...:         y="Maximum Temperature",
   ...:         color=str.capitalize,
   ...:     )
   ...:     .show()
   ...: )

要创建一组子图(每个一个),您可以使用Plot对象的.facet()方法。通过传入一个包含对您希望拆分的数据的引用的字符串(在本例中为 col="month"),您可以将每个月分隔到其自己的列中。您还使用了 Plot.layout() 方法将输出大小调整为 15 英寸 x 5 英寸的宽度。 这使得情节可读。

面向对象版本的绘图的最终版本现在如下所示:

正如您所看到的,每个子图仍然保留自己的颜色和标记样式。对象接口允许您通过对现有代码进行细微调整来创建多个子图,但不会使其变得更加复杂。对于对象,无需从头开始使用完全不同的函数。

决定使用哪个接口

Seaborn 对象界面旨在为您提供更直观和可扩展的数据可视化方式。它通过模块化来实现这一点。无论您想要可视化什么,所有绘图都以相同的 Plot 对象开始,然后使用其他 Mark 对象(例如 Dots)进行自定义。使用对象还可以使您的绘图代码看起来更统一。

对象接口还允许您创建更复杂的绘图,而无需使用更复杂的代码来执行此操作。随时添加对象的能力意味着您可以逐步构建一些非常令人印象深刻的绘图。

这个界面的灵感来自于图形语法。因此,您会发现它类似于 Vega-Altair、plotnine 和 R 的 ggplot2 等绘图库,它们都具有相同的灵感。

对象 API 也仍在开发中。开发商对此毫不掩饰。尽管 Seaborn 开发人员希望对象 API 成为其未来,但仍然值得关注文档每个版本页面中的新增内容,以了解这两个接口是如何改进的。不过,现在了解对象 API 将对您将来很有帮助。

这意味着您不应该完全放弃seaborn 绘图功能。它们仍然很受欢迎并且被广泛使用。如果您对他们为您提供的产品感到满意,那么就没有压倒性的理由去改变。此外,seaborn 开发人员仍然维护它们并根据他们认为合适的方式改进它们。它们绝不是过时的。

另请记住,虽然您个人可能更喜欢一种界面而不是另一种,但您可能需要将每种界面用于不同的绘图以满足您的要求。

在本教程的其余部分中,您将使用函数和对象创建一系列不同的绘图。再说一次,这不会详尽地涵盖您可以使用seaborn 做的所有事情,但它会向您展示更多有用的技术来帮助您。再次强调,请密切关注文档,了解有关该库可以做什么的更多详细信息。

使用函数创建不同的seaborn绘图

在本节中,您将学习如何使用 seaborn 的函数绘制一系列常见的绘图类型。当您完成这些示例时,请记住它们旨在说明使用 seaborn 的原则。 这些才是你真正应该掌握的学习要点,让你以后能够拓展知识面。

首先,您将看一些分类图的示例。

使用函数创建分类图

Seaborn 的分类图是一系列图,显示一组数值与一个或多个不同类别之间的关系。这使您可以看到不同类别的值如何变化。

假设您想要调查 cycle_crossings_apr_jun.csv 中详细介绍的所有四座桥梁的每日过境情况。尽管执行此操作所需的所有数据都已存在,但其格式并不完全适合通过桥进行分析:

In [1]: import pandas as pd
   ...:
   ...: crossings = pd.read_csv("cycle_crossings_apr_jun.csv")
   ...: crossings.head()
Out[1]:
         date        day  month  max_temp  min_temp  precipitation  \
0  01/04/2017   Saturday  April      46.0      37.0           0.00
1  02/04/2017     Sunday  April      62.1      41.0           0.00
2  03/04/2017     Monday  April      63.0      50.0           0.03
3  04/04/2017    Tuesday  April      51.1      46.0           1.18
4  05/04/2017  Wednesday  April      63.0      46.0           0.00

   Brooklyn  Manhattan  Williamsburg  Queensboro
0   606       1446          1915        1430
1  2021       3943          4207        2862
2  2470       4988          5178        3689
3   723       1913          2279        1666
4  2807       5276          5711        4197

问题是,要按桥梁类型对数据进行分类,您需要将每座桥梁的日常数据放在一列中。目前,每座桥都有一个单独的柱子。要对此进行排序,您需要使用 DataFrame.melt() 方法。这会将数据从当前的宽格式更改为所需的长格式。您可以使用以下代码来执行此操作:

In [2]: bridge_crossings = crossings.melt(
   ...:     id_vars=["day", "date"],
   ...:     value_vars=[
   ...:         "Brooklyn", "Manhattan",
   ...:         "Williamsburg", "Queensboro"
   ...:     ],
   ...:     var_name="Bridge",
   ...:     value_name="Crossings",
   ...: ).rename(columns={"day": "Day", "date": "Date"})
   ...:
   ...: bridge_crossings.head()
Out[2]:
         Day        Date    Bridge  Crossings
0   Saturday  01/04/2017  Brooklyn        606
1     Sunday  02/04/2017  Brooklyn       2021
2     Monday  03/04/2017  Brooklyn       2470
3    Tuesday  04/04/2017  Brooklyn        723
4  Wednesday  05/04/2017  Brooklyn       2807

要重新组织 DataFrame 以使每个桥的数据显示在同一列中,首先将 id_vars=["day", "date"] 传递给 .melt()。这些是标识符变量,用于标识正在重新格式化的数据。在这种情况下,每个 DayDate 值将用于识别当前图中和未来图中每个桥梁的数据。

您还传入一个值列表,您希望其 DayDate 数据出现在一列中。 在本例中,您将 value_vars 设置为桥梁列表,因为您希望列出每个桥梁交叉口值及其日期和日期。

为了使绘图标签更有意义并大写以保持整洁,您可以传入 var_nameval_name 参数以及值 BridgeCrossings分别。这将创建两个新列。 Bridge 列将包含所有桥梁名称,而 Crossings 列将包含每个日期和日期的交叉口。

最后,使用 DataFrame.rename() 方法将 daydate 列名称更新为 Day 和分别为日期。这将使您不必像以前那样更改各种绘图标签。

正如您从输出中看到的,新的 bridge_crossings DataFrame 的数据格式使您可以更轻松地使用。请注意,虽然仅显示了一些布鲁克林大桥数据,但其他桥梁在完整的 DataFrame 中列在其下方。

您可以使用数据生成条形图,显示一周中每一天所有四座桥梁的每日总穿越量:

In [3]: import matplotlib.pyplot as plt
   ...: import seaborn as sns
   ...:
   ...: sns.barplot(
   ...:     data=bridge_crossings,
   ...:     x="Day", y="Crossings",
   ...:     hue="Bridge", errorbar=None,
   ...:     estimator="sum",
   ...: )
   ...: plt.show()

此代码类似于前面分析提示数据的条形图示例。这次,您使用 hue 参数对每座桥梁的数据进行不同的着色,并通过设置 estimator="sum" 绘制每天的交叉总数。这是您希望用来计算总交叉点的函数的名称。

结果图如下所示:

正如您所看到的,条形图包含七组,每组四个条形,一周中每一天的每座桥都有一个。

从图中可以看出,威廉斯堡大桥似乎是整体上最繁忙的,其中星期三是最繁忙的一天。您决定进一步调查此事。您决定为三个月的数据中的每一个生成威廉斯堡周三数据的箱线图。这将为您提供一些数据统计分析:

In [4]: wednesday_crossings = crossings.loc[
   ...:     crossings.day.isin(["Wednesday"])
   ...: ].rename(columns={"month": "Month"})
   ...:
   ...: (
   ...:     sns.boxplot(
   ...:         data=wednesday_crossings, x="day",
   ...:         y="Williamsburg", hue="Month",
   ...:     )
   ...:     .set(xlabel=None)
   ...: )
   ...:
   ...: plt.show()

这次,您使用轴级 boxplot() 函数来生成绘图。正如您所看到的,它的参数与您已经看到的类似。 xy 参数告诉函数要使用哪些数据,而设置 hue="month" 则为每个月提供单独的箱线图。您还可以在绘图上设置xlabel=None。这将删除默认的day标签,但保留Wednesday

你的情节看起来像这样:

对于三个月中的每个月,每个框的高度显示四分位数范围,而穿过每个框的中心线显示中值。每个框外的水平须线显示上四分位数和下四分位数,而圆圈则显示异常值。

使用您迄今为止学到的原则和seaborn文档,您可能想尝试以下练习:

任务 1:看看是否可以仅为周末数据创建多个条形图,每天都在单独的图上但在同一行中。每个子图应显示每座桥梁的最高交叉口数量。

任务 2:看看您是否可以连续绘制三个箱线图,其中仅包含每周三布鲁克林大桥的单独交叉口。

任务 1 解决方案 这是一种使用条形图分别绘制每座桥梁周六和周日最大交叉口的方法:

In [1]: import matplotlib.pyplot as plt
   ...: import pandas as pd
   ...: import seaborn as sns
   ...:
   ...: crossings = pd.read_csv("cycle_crossings_apr_jun.csv")
   ...:
   ...: bridge_crossings = crossings.melt(
   ...:     id_vars=["day", "date"],
   ...:     value_vars=[
   ...:         "Brooklyn", "Manhattan",
   ...:         "Williamsburg", "Queensboro",
   ...:     ],
   ...:     var_name="Bridge",
   ...:     value_name="Crossings",
   ...: ).rename(columns={"day": "Day", "date": "Date"})
   ...:
   ...: weekend = bridge_crossings.loc[
   ...:     bridge_crossings.Day.isin(["Saturday", "Sunday"])
   ...: ]
   ...:
   ...: (
   ...:     sns.catplot(
   ...:         data=weekend, x="Day", y="Crossings",
   ...:         hue="Bridge", col="Day", errorbar=None,
   ...:         estimator="max", kind="bar",
   ...:     ).set(xlabel=None)
   ...: )
   ...:
   ...: plt.show()

与之前一样,您使用 read_csv() 读取原始数据,然后使用 .melt() 旋转数据,以便每座桥梁的交叉口显示在一列中。

然后,您使用 .isin() 仅提取周末数据。获得此信息后,您可以使用catplot()函数来创建绘图。通过传入 col="Day",每天的数据被分成不同的子图。使用 estimator="max",您可以确保只绘制每日最高交叉点。 kind="bar" 参数为您生成所需的绘图类型。

任务 2 解决方案 为每个月的布鲁克林大桥星期三交叉口创建箱线图的一种方法如下所示:

In [2]: import matplotlib.pyplot as plt
   ...: import pandas as pd
   ...: import seaborn as sns
   ...:
   ...: crossings = pd.read_csv("cycle_crossings_apr_jun.csv")
   ...:
   ...: wednesday = (
   ...:     crossings
   ...:     .loc[crossings.day.isin(values=["Wednesday"])]
   ...:     .rename(columns={"month": "Month"})
   ...: )
   ...:
   ...: (
   ...:     sns.catplot(
   ...:         data=wednesday, x="day", y="Brooklyn",
   ...:         col="Month", kind="box",
   ...:     )
   ...:     .set(xlabel=None)
   ...: )
   ...:
   ...: plt.show()

这次,读入数据后,您使用 .isin() 仅提取星期三的数据。一旦你有了这个,你就可以使用catplot()来生成绘图。通过传入 x="day",您可以确保将每天的数据放置到不同的子图上,而通过设置 y="Brooklyn",您可以确保仅将绘制了布鲁克林大桥的数据。要分隔月份,您可以设置 col="Month",同时设置 kind="box" 会生成箱线图。

接下来,您将看一些分布图的示例。

使用函数创建分布图

Seaborn 的分布图是一系列图,可让您查看一系列样本中的数据分布。这可以揭示数据的趋势或其他见解,例如允许您查看数据是否符合常见的统计分布。

最常见的分布图类型之一是histplot()。这允许您创建直方图,通过将数据分组到不同的范围或容器中来可视化数据的分布非常有用。

在本部分中,您将使用 cereals.csv 文件。该文件包含来自多个制造商的各种流行早餐麦片的数据。原始数据来自 Kaggle,可根据知识共享许可免费获取。

您需要做的第一件事是将谷物数据读入 DataFrame:

In [1]: import pandas as pd
   ...:
   ...: cereals_data = (
   ...:     pd.read_csv("cereals_data.csv")
   ...:     .rename(columns={"rating": "Rating"})
   ...: )
   ...:
   ...: cereals_data.head()
Out[1]:
                      name   manufacturer  calories  protein  fat  ...  \
0  Apple Cinnamon Cheerios  General Mills       110        2    2  ...
1                  Basic 4  General Mills       130        3    2  ...
2                 Cheerios  General Mills       110        6    2  ...
3    Cinnamon Toast Crunch  General Mills       120        1    3  ...
4                 Clusters  General Mills       110        3    2  ...

     vitamins  shelf  weight  cups     Rating
0          25      1    1.00  0.75  29.509541
1          25      3    1.33  0.75  37.038562
2          25      1    1.00  1.25  50.764999
3          25      2    1.00  0.75  19.823573
4          25      3    1.00  0.50  40.400208

首先,假设您想详细了解不同谷物食品的谷物评级有何不同。一种方法是创建一个直方图,显示每种谷物的评分计数分布。该数据包含包含此信息的 Rating 列。您可以使用 histplot() 函数创建绘图:

In [2]: import matplotlib.pyplot as plt
   ...: import seaborn as sns
   ...:
   ...: (
   ...:     sns.histplot(data=cereals_data, x="Rating", bins=10)
   ...:     .set(title="Cereal Ratings Distribution")
   ...: )
   ...:
   ...: plt.show()

与您使用过的所有轴级函数一样,您可以将要使用的 DataFrame 分配给 histplot()data 参数。 x 参数包含您要计数的值。在此示例中,您决定将数据分组到十个大小相等的箱中。这将在你的图中产生十列:

正如您所看到的,谷物评级的分布偏向低端。这些谷物最受欢迎的评级是三十多岁。

另一种常见的分布图类型是核密度估计(KDE)图。这使您可以分析连续数据并估计其中出现任何值的概率。要创建早餐麦片分析的 KDE 曲线,您可以使用以下代码:

In [3]: (
   ...:     sns.kdeplot(data=cereals_data, x="Rating")
   ...:     .set(title="Cereal Ratings KDE Curve")
   ...: )
   ...:
   ...: plt.show()

这将分析 cereals_data 数据系列中的每个 Rating 值,并根据其出现的概率绘制 KDE 曲线。传递给 kdeplot() 函数的各种参数与您之前使用的 histplot() 中的参数具有相同的含义。生成的 KDE 曲线如下所示:

该曲线进一步证明谷物评级的分布偏向低端。如果您随机选择数据集中的任何早餐麦片,它很可能包含 40 左右的评分。

地毯图是另一种用于可视化数据分布密度的图。它包含一组垂直线,就像捻绒地毯中的捻线一样,但其间距随它们所代表的数据的分布密度而变化。更常见的数据由更紧密排列的线表示,而不太常见的数据由更宽间距的线表示。

地毯图本身就是一个独立的图,但它通常会添加到另一个更明确的图中。您可以通过确保两个函数引用相同的底层 Matplotlib 图形来做到这一点。为此,请确保诸如 plt.figure() 之类的代码(创建单独的底层 Matplotlib 图形对象)不会出现在每对函数之间

假设您想通过在 KDE 图上创建地毯图来可视化交叉数据:

In [4]: sns.kdeplot(data=cereals_data, x="Rating")
   ...:
   ...: (
   ...:     sns.rugplot(
   ...:         data=cereals_data, x="Rating",
   ...:         height=0.2, color="black",
   ...:     )
   ...:     .set(title="Cereal Rating Distribution")
   ...: )
   ...:
   ...: plt.show()

kdeplot() 函数与您之前使用的函数相同。此外,您还使用 rugplot() 函数添加了新的地毯图。两者的 datax 参数相同,以确保它们都匹配。通过设置 height=0.2,地毯图将占据图高度的 20%,而通过设置 color="black",它将更加突出。

情节的最终版本如下所示:

正如您所看到的,随着 KDE 曲线值的增加,地毯图的纤维变得更加捆绑在一起。 相反,KDE 值越低,地毯图的纤维就越稀疏。

使用您迄今为止学到的原则和seaborn文档,您可能想尝试以下练习:

任务 1:制作一个显示谷物评级分布的直方图,以便每个制造商都有一个单独的条形图。保持相同的十个垃圾箱。

任务 2:看看是否可以仅使用一个函数将 KDE 图叠加到原始评分直方图上。

任务 3:更新任务 1 的答案,以便每个制造商的卡路里数据及其自己的 KDE 曲线显示在单独的图上。

任务 1 解决方案 以下是绘制每个制造商的谷物评级分布的一种方法:

In [1]: import matplotlib.pyplot as plt
   ...: import pandas as pd
   ...: import seaborn as sns
   ...:
   ...: cereals_data = (
   ...:     pd.read_csv("cereals_data.csv")
   ...:     .rename(columns={"rating": "Rating"})
   ...: )
   ...:
   ...: sns.histplot(
   ...:     data=cereals_data, x="Rating",
   ...:     bins=10, hue="manufacturer",
   ...:     multiple="dodge",
   ...: )
   ...:
   ...: plt.show()

读入数据后,您几乎可以调整之前绘制所有制造商的分布时使用的代码。通过将 histplot() 函数的 huemultiple 参数设置为 "manufacturer""dodge" 分别,您使用每个制造商的单独条形图分隔数据,并确保它们不重叠。

任务 2 解决方案 叠加 KDE 图的一种方法如下所示:

In [2]: import matplotlib.pyplot as plt
   ...: import pandas as pd
   ...: import seaborn as sns
   ...:
   ...: cereals_data = (
   ...:     pd.read_csv("cereals_data.csv")
   ...:     .rename(columns={"rating": "Rating"})
   ...: )
   ...:
   ...: sns.histplot(
   ...:    data=cereals_data,
   ...:    x="Rating", kde=True, bins=10,
   ...: )
   ...: plt.show()

您还可以通过对原始评分直方图进行小幅更新来解决此问题。您所需要做的就是将其 kde 参数设置为 True。 这将添加 KDE 图。

任务 3 解决方案 您可以通过以下一种方法分别绘制每个制造商的评级分布及其 KDE 曲线:

In [3]: import matplotlib.pyplot as plt
   ...: import pandas as pd
   ...: import seaborn as sns
   ...:
   ...: cereals_data = (
   ...:     pd.read_csv("cereals_data.csv")
   ...:     .rename(columns={
   ...:         "rating": "Rating",
   ...:         "manufacturer": "Manufacturer",
   ...:     })
   ...: )
   ...:
   ...: sns.displot(
   ...:     data=cereals_data, x="Rating",
   ...:     bins=10, hue="Manufacturer",
   ...:     kde=True, col="Manufacturer",
   ...: )
   ...:
   ...: plt.show()

此解决方案与任务二类似,只是您使用图形级 displot() 函数而不是轴级 histplot() 函数。参数类似,只是将 huecolumn 参数设置为 manufacturer。这些将分别将每个制造商的数据分成单独的颜色和绘图。默认情况下会创建直方图,但您也可以显式指定 kind="hist"

接下来,您将看一些关系图的示例。

使用函数创建关系图

Seaborn 的关系图是一系列图,可让您研究两组数据之间的关系。您之前在创建散点图时看到过其中一个示例。

另一种常见的关系图是线图。线图将信息显示为一组用直线段连接的数据标记点。它们通常用于可视化时间序列。要在seaborn中创建一个,您可以使用lineplot()函数。

在本部分中,您将重用之前用作关系图基础的 crossingsbridge_crossings DataFrame。

假设您想查看 4 月至 6 月这三个月中布鲁克林大桥每日过桥量的趋势。线图是向您展示这一点的一种方式:

In [1]: import matplotlib.pyplot as plt
   ...: import pandas as pd
   ...: import seaborn as sns
   ...:
   ...: crossings = pd.read_csv("cycle_crossings_apr_jun.csv")
   ...:
   ...: sns.set_theme(style="darkgrid")
   ...:
   ...: (
   ...:     sns.lineplot(data=crossings, x="date", y="Brooklyn")
   ...:     .set(
   ...:         title="Brooklyn Bridge Daily Crossings",
   ...:         xlabel=None,
   ...:     )
   ...: )
   ...:
   ...: plt.xticks(
   ...:     ticks=[
   ...:         "01/04/2017", "01/05/2017", "01/06/2017", "30/06/2017"
   ...:     ],
   ...:     rotation=45,
   ...: )
   ...:
   ...: plt.show()

为了增强绘图的外观,您可以调用seaborn的set_theme()函数并设置darkgrid背景主题。 这为绘图提供了阴影背景和白色网格,以便于阅读。请注意,此设置将应用于所有后续绘图,除非您将其重置回默认的 white 值。

与所有seaborn 函数一样,您首先在DataFrame 中传递lineplot()。线图将显示时间序列,因此 x 值被指定为 date 系列,而 y 值被指定为 布鲁克林系列。这些参数足以绘制可视化效果。

x 系列包含超过 90 个值,这意味着在绘制绘图时它们将被压在一起且无法读取。为了澄清这一点,您决定使用 Matplotlib xticks() 函数来旋转并仅显示三个月中每个月的开始日期,以及六月的最后一天。您的读者可以使用此信息以及背景网格来推断其余日期。您还可以为绘图指定标题并删除其 xlabel

您创建的情节如下所示:

正如您所看到的,线图绘制了每个每日交叉值,并将这些值用直线段连接在一起。您可能会惊讶地发现桥梁交叉口高度的变化。有时,过境点数量少于 500 人次,而另一些日子则接近 4,000 人次。

使用您迄今为止学到的原则和seaborn文档,您可能想尝试以下练习:

任务 1:使用适当的数据集,生成一个单线图,显示 4 月至 6 月所有桥梁的交叉口。

任务 2:通过为每个桥梁创建单独的子图来阐明任务 1 的解决方案。

任务 1 解决方案 这是在单线图上绘制桥梁交叉口的一种方法:

In [1]: import matplotlib.pyplot as plt
   ...: import pandas as pd
   ...: import seaborn as sns
   ...:
   ...: crossings = pd.read_csv("cycle_crossings_apr_jun.csv")
   ...:
   ...: bridge_crossings = crossings.melt(
   ...:     id_vars=["day", "date"],
   ...:     value_vars=[
   ...:         "Brooklyn", "Manhattan",
   ...:         "Williamsburg", "Queensboro",
   ...:     ],
   ...:     var_name="Bridge",
   ...:     value_name="Crossings",
   ...: ).rename(columns={"day": "Day", "date": "Date"})
   ...:
   ...: (
   ...:     sns.lineplot(
   ...:         data=bridge_crossings, x="Date", y="Crossings",
   ...:         hue="Bridge", style="Bridge",
   ...:     )
   ...:     .set_title("Daily Bridge Crossings")
   ...: )
   ...:
   ...: plt.xticks(
   ...:     ticks=[
   ...:         "01/04/2017", "01/05/2017", "01/06/2017", "30/06/2017"
   ...:     ],
   ...:     rotation=45,
   ...: )
   ...:
   ...: plt.show()

您再次读入数据并使用 DataFrame 的 .melt() 方法对其进行透视,将每个桥的数据放在同一列中。然后使用 lineplot() 函数绘制绘图。通过将 huestyle 设置为 "Bridge",您可以确保每个桥的数据显示为具有不同颜色的单独行,并且外貌。为了使 x 轴不那么拥挤,您将其刻度设置为显示的四个日期位置并将它们旋转 45 度。

任务 2 解决方案 分离之前的线图的一种方法如下所示:

In [2]: import matplotlib.pyplot as plt
   ...: import pandas as pd
   ...: import seaborn as sns
   ...:
   ...: crossings = pd.read_csv("cycle_crossings_apr_jun.csv")
   ...:
   ...: bridge_crossings = crossings.melt(
   ...:     id_vars=["day", "date"],
   ...:     value_vars=[
   ...:         "Brooklyn", "Manhattan",
   ...:         "Williamsburg", "Queensboro",
   ...:     ],
   ...:     var_name="Bridge",
   ...:     value_name="Crossings",
   ...: ).rename(columns={"day": "Day", "date": "Date"})
   ...:
   ...: sns.relplot(
   ...:     data=bridge_crossings, kind="line",
   ...:     x="Date", y="Crossings",
   ...:     hue="Bridge", col="Bridge",
   ...: )
   ...:
   ...: plt.xticks(
   ...:     ticks=[
   ...:         "01/04/2017", "01/05/2017", "01/06/2017", "30/06/2017"
   ...:     ],
   ...:     rotation=45,
   ...: )
   ...:
   ...: plt.show()

此代码与任务 1 的解决方案类似,只是这次您使用 relplot() 函数。 通过设置 col="Bridge",您可以将每个桥的数据分成自己的绘图。

接下来,您将看一些回归图的示例。

使用函数创建回归图

Seaborn 的回归图是一系列图,可让您研究两组数据之间的关系。它们在数据集之间生成回归分析,帮助您可视化它们的关系。

两个轴级回归绘图函数是 regplot()residplot() 函数。它们分别产生回归分析和回归分析的残差。

在本节中,您将继续使用之前使用的交叉 DataFrame。

之前您使用了 scatterplot() 函数创建了比较最低和最高温度的散点图。如果您使用 regplot() 来代替,您将产生相同的结果,只是在其上叠加了一条线性回归线:

In [1]: import matplotlib.pyplot as plt
   ...: import pandas as pd
   ...: import seaborn as sns
   ...:
   ...: crossings = pd.read_csv("cycle_crossings_apr_jun.csv")
   ...:
   ...: (
   ...:     sns.regplot(
   ...:         data=crossings, x="min_temp",
   ...:         y="max_temp", ci=95,
   ...:     )
   ...:     .set(
   ...:         title="Regression Analysis of Temperatures",
   ...:         xlabel="Minimum Temperature",
   ...:         ylabel="Maximum Temperature",
   ...:     )
   ...: )
   ...:
   ...: plt.show()

和以前一样,regplot() 函数需要一个 DataFrame,以及要绘制的 xy 系列。这足以绘制散点图以及线性回归线。生成的回归图如下所示:

线周围的阴影是置信区间。默认情况下,该值设置为 95%,但可以通过相应设置 ci 参数进行调整。您可以通过设置 ci=None 来删除置信区间。

使用 regplot() 最令人沮丧的方面之一是它不允许您将回归方程或 R 平方值插入到绘图中。尽管 regplot() 内部知道这些,但它不会向您透露它们。如果您想查看方程式,则必须单独计算并显示它。

为此,您可以使用 scikit-learn 库中的 LinearRegression 类。此类的对象允许您计算两个变量之间的普通最小二乘线性回归。

要使用它,您必须首先使用 !python -m pip install scikit-learn 安装 scikit-learn。和以前一样,如果您在命令行中工作,则不需要感叹号 (!)。安装 scikit-learn 库后,您可以执行回归:

In [2]: from sklearn.linear_model import LinearRegression
   ...:
   ...: x = crossings.loc[:, ["min_temp"]]
   ...: y = crossings.loc[:, "max_temp"]
   ...:
   ...: model = LinearRegression()
   ...: model.fit(x, y)
   ...:
   ...: r_squared = f"R-Squared: {model.score(x, y):.2f}"
   ...: best_fit = (
   ...:     f"y = {model.coef_[0]:.4f}x"
   ...:     f"{model.intercept_:+.4f}"
   ...: )
   ...:
   ...: ax = sns.regplot(
   ...:     data=crossings, x="min_temp", y="max_temp",
   ...:     line_kws={"label": f"{best_fit}\n{r_squared}"},
   ...: )
   ...:
   ...: ax.set_xlabel("Minimum Temperature")
   ...: ax.set_title("Regression Analysis of Temperatures")
   ...: ax.set_ylabel("Maximum Temperature")
   ...: ax.legend()
   ...:
   ...: plt.show()

首先,从 sklearn.linear_model 导入 LinearRegression。正如您很快就会看到的,您将需要它来执行线性回归计算。然后,您创建一个 pandas DataFrame 和一个 pandas Series。您的 x 是一个包含 min_temp 列数据的 DataFrame,而 y 是一个包含 max_temp 列数据的系列数据。您可能会在多个功能上进行回归,这就是为什么 x 被定义为具有列列表的 DataFrame 的原因。

接下来,您创建一个 LinearRegression 实例,并使用 .fit() 将两个数据集传递给它。这将为您执行实际的回归计算。默认情况下,它使用普通最小二乘法 (OLS) 来执行此操作。

创建并填充 LinearRegression 实例后,其 .score() 方法将计算 R 平方(或确定系数)值。这可以衡量最佳拟合线与实际值的接近程度。在您的分析中,R 平方值 0.78 表示最佳拟合线与实际值之间的准确度为 78%。您将其存储在名为 r_squared 的字符串中,以便稍后进行绘图。您对整洁度值进行四舍五入。

LinearRegression 实例还计算线性回归线的斜率及其 y 截距。它们分别存储在 .coef_[0].intercept_ 属性中。

要绘制绘图,您可以像以前一样使用 regplot() 函数,但使用其 line_kws 参数来定义回归的 label 属性线。它作为 Python 字典传入,其键是您要设置的参数,其值是该参数的值。在本例中,它是一个字符串,其中包含您之前计算的 best_fit 方程和 r_squared 值。

您将 regplot()(一个 Matplotlib Axes 对象)分配给名为 ax 的变量,以允许您给出绘图及其轴标题。最后,您使用 .legend() 方法显示其标签的内容,即线性回归方程和 R 平方值。

您更新后的绘图现在如下所示:

正如您所看到的,数据点的最佳拟合直线方程已添加到您的图中。

使用您迄今为止学到的原则和seaborn文档,您可能想尝试以下练习:

任务 1:重做之前的回归图,但这次创建一个图,显示三个月中每个月的单独回归线以及方程。

任务 2:使用适当的图形级函数为每个月创建单独的回归图。

任务 3:看看您是否可以将正确的方程添加到您在任务 2 中创建的三个图中的每一个上。提示:研究 FacetGrid.map_dataframe()方法。

任务 1 解决方案 将每个月的每个回归绘制在同一个图上的一种方法是:

In [1]: import matplotlib.pyplot as plt
   ...: import pandas as pd
   ...: import seaborn as sns
   ...: from sklearn.linear_model import LinearRegression
   ...:
   ...: def calculate_regression(month, data):
   ...:     x = data.loc[:, ["min_temp"]]
   ...:     y = data.loc[:, "max_temp"]
   ...:     model = LinearRegression()
   ...:     model.fit(x, y)
   ...:     r_squared = (
   ...:         f"R-Squared: {model.score(x, y):.2f}"
   ...:     )
   ...:     best_fit = (
   ...:         f"{month}\ny = {model.coef_[0]:.4f}x"
   ...:         f"{model.intercept_:+.4f}"
   ...:     )
   ...:     return r_squared, best_fit
   ...:
   ...: def drawplot(month, crossings):
   ...:     monthly_crossings = crossings[
   ...:         crossings.month == month
   ...:     ]
   ...:     r_squared, best_fit = calculate_regression(
   ...:         month, monthly_crossings
   ...:     )
   ...:
   ...:     ax = sns.regplot(
   ...:         data=monthly_crossings, x="min_temp",
   ...:         y="max_temp", ci=None,
   ...:         line_kws={"label": f"{best_fit}\n{r_squared}"},
   ...:     )
   ...:     ax.set_title(
   ...:         "Regression Analysis of Temperatures"
   ...:     )
   ...:     ax.legend()
   ...:
   ...: crossings = pd.read_csv("cycle_crossings_apr_jun.csv")
   ...:
   ...: months = ["April", "May", "June"]
   ...: for month in months:
   ...:     drawplot(month, crossings)
   ...:
   ...: plt.show()

与前面的示例一样,您需要手动计算每条线的回归方程。为此,您需要创建一个 calculate_regression() 函数,该函数接受一个表示要确定行的月份的字符串,以及一个包含数据的 DataFrame。该函数的主体使用与前面的示例类似的代码来计算线性回归方程。

回归图再次使用seaborn的regplot()函数生成。您还将代码放入 drawplot() 函数中,以便可以多次调用它,每个绘制的月份调用一次。这也与您之前看到的示例类似。

主代码读取源数据,然后在所需的三个月中的每个月的 for 循环中调用 drawplot()。它传入一个字符串来标识月份以及包含数据的 DataFrame。

任务 2 解决方案 将每个月的每个回归绘制在同一个图上的一种方法是:

In [2]: import matplotlib.pyplot as plt
   ...: import pandas as pd
   ...: import seaborn as sns
   ...:
   ...: crossings = pd.read_csv("cycle_crossings_apr_jun.csv")
   ...:
   ...: sns.lmplot(
   ...:     data=crossings, x="min_temp",
   ...:     y="max_temp", col="month",
   ...: )
   ...:
   ...: plt.show()

这次,您使用seaborn的lmplot()函数来为您进行绘图。要按月分隔每个子图,请设置 col="month"

任务 3 解决方案 将每个月的每个回归绘制在同一个图上的一种方法是:

In [3]: import matplotlib.pyplot as plt
   ...: import pandas as pd
   ...: import seaborn as sns
   ...: from sklearn.linear_model import LinearRegression
   ...:
   ...: crossings = pd.read_csv("cycle_crossings_apr_jun.csv")
   ...:
   ...: def regression_equation(data, **kws):
   ...:     x = data.loc[:, ["min_temp"]]
   ...:     y = data.loc[:, "max_temp"]
   ...:     model = LinearRegression()
   ...:     model.fit(x, y)
   ...:     r_squared = (
   ...:         f"R-Squared: {model.score(x, y):.2f}"
   ...:     )
   ...:     best_fit = (
   ...:         f"y = {model.coef_[0]:.4f}x"
   ...:         f"{model.intercept_:+.4f}"
   ...:     )
   ...:     ax = plt.gca()  # Get current Axes.
   ...:     ax.text(
   ...:         0.1, 0.6,
   ...;         f"{best_fit}\n{r_squared}",
   ...:         transform=ax.transAxes,
   ...:     )
   ...:
   ...: crossings = pd.read_csv("cycle_crossings_apr_jun.csv")
   ...:
   ...: sns.lmplot(
   ...:     data=crossings, x="min_temp",
   ...:     y="max_temp", col="month",
   ...: ).map_dataframe(regression_equation)
   ...:
   ...: plt.show()

和以前一样,您使用 lmplot() 函数来创建绘图。您设置 col="month" 以确保为每个月生成单独的图。接下来,您必须手动计算每个月数据的回归方程。您可以在 regression_equation() 函数中进行计算。该函数的标头显示它采用 DataFrame 作为其 data 参数以及通过关键字传递的一系列其他参数。

在这里,您需要为您想要其方程的每个月的数据调用一次 regression_equation()。为此,您可以使用seaborn的FacetGrid.map_dataframe()方法。请记住,FacetGrid 是放置每个子图的对象,它是由 lmplot() 创建的。

通过调用 .map_dataframe() 并传递 regression_equation 作为其参数,将为每个月调用 regression_equation() 函数。它传递的 data 最初传递给 lmplot(),但在 col="month" 上进行过滤。然后,它使用这些来计算每个单独月份数据的回归方程。

接下来,您将把注意力转向使用 seaborn 的对象接口。

使用对象创建seaborn数据图

之前您已经了解了如何使用seaborn的Plot对象作为绘图的背景,而您必须使用一个或多个Mark对象为其提供内容。 在本节中,您将学习如何使用更多这些对象的原理,以及如何使用其他一些常见的seaborn对象。与使用函数的部分一样,记住要集中精力理解原理。详细信息在文档中。

使用主要数据可视化对象

Seaborn 对象接口包括几个 Mark 对象,包括 LineBarArea,以及 Mark 对象。Dot 你已经见过了。尽管其中每个都可以单独生成绘图,但您也可以将它们组合起来以生成更复杂的可视化效果。

例如,假设您想要准备一个绘图,以便可视化 路口 数据第一周的最低温度:

In [1]: import pandas as pd
   ...: import seaborn.objects as so
   ...:
   ...: crossings = pd.read_csv(
   ...:     "cycle_crossings_apr_jun.csv",
   ...:     parse_dates=["date"],
   ...:     dayfirst=True,
   ...: )
   ...:
   ...: first_week = (
   ...:     crossings
   ...:     .loc[crossings.month == "April"]
   ...:     .sort_values(by="date")
   ...:     .head(7)
   ...: )
   ...:
   ...: (
   ...:     so.Plot(data=first_week, x="day", y="min_temp")
   ...:     .add(so.Line(color="black", linewidth=3, marker="o"))
   ...:     .add(so.Bar(color="green", fill=False, edgewidth=3))
   ...:     .add(so.Area(color="yellow"))
   ...:     .label(
   ...:         x="Day", y="Temperature",
   ...:         title="Minimum Temperature",
   ...:     )
   ...:     .show()
   ...: )

您确保 date 列被解释为日期,以便您可以计算 4 月的前 7 天。您可以通过过滤 crossings 来获取 4 月数据、按 date 排序并使用 .head(7) 创建 first_week仅获取前七行,其中包含第一周的数据。

与使用对象创建的所有 seaborn 绘图一样,您必须首先创建一个包含对所需数据的引用的 Plot 对象。在这种情况下,您必须为 data 提供 first_week DataFrame 以及其中的 daymin_temp 系列、xy 分别。这些值将可供您稍后添加到绘图中的任何对象使用。

要向绘图添加内容,您可以使用 Plot.add() 方法并传入您想要添加的一个或多个对象。每次调用 Plot.add() 时,都会将其参数添加到 Plot 对象的单独中。 在本例中,您调用了 .add() 三次,因此将添加三个单独的层。

第一层包含一个 Line 对象,您可以使用它在绘图上绘制线条并创建线图。通过传入 colorlinewidthmarker 参数,您可以定义 Line 对象的外观。绘图上将出现一组连接相邻数据点的线。

第二层包含一个 Bar 对象。这些用于条形图中。您再次指定一些参数来定义条形图的外观。然后将它们应用到绘图上的每个条形图上。

最后一层添加一个 Area 对象。这提供了数据下方的阴影。在本例中,它会是yellow,因为您已将其指定为它的color 属性。

最后,调用 Plot.label() 方法,为绘图提供标题和大写标签轴。

你的情节看起来像这样:

正如您所看到的,所有三个对象都已放置在绘图上。允许您向 Plot 对象添加单独的对象,为您的最终可视化外观提供了极大的灵活性。您不再受函数如何决定绘图外观的限制。然而,正如您在这里所看到的,您可能会在没有意识到的情况下做得过度。

使用 MoveStat 对象增强绘图

接下来,假设您想要分析三个月内每天的中位最高气温。 为此,您需要使用seaborn的StatMove对象类型:

In [2]: import pandas as pd
   ...: import seaborn.objects as so
   ...:
   ...: crossings = pd.read_csv("cycle_crossings_apr_jun.csv")
   ...:
   ...: (
   ...:     so.Plot(
   ...:         data=crossings, x="month",
   ...:         y="max_temp", color="day",
   ...:     )
   ...:     .add(
   ...:         so.Bar(),
   ...:         so.Agg(func="median"),
   ...:         so.Dodge(gap=0.1),
   ...:     )
   ...:     .label(
   ...:         x=None, y="Temperature",
   ...:         title="Median Temperature", color="Day",
   ...:     )
   ...:     .show()
   ...: )

像往常一样,您首先定义 Plot 对象。这次您添加了一个 color 参数。但是,您不是指定实际颜色,而是定义day 数据系列。 这意味着添加的所有图层都会将绘图分为不同的日子,每一天都有不同的颜色。这在概念上与您之前看到的 hue 参数类似,但是 hue 并不存在于 Plot 中。

您决定使用 Bar 对象来表示您的数据,但在这种情况下,这些对象本身还不够。

要显示绘制的每个温度条上的中值,您需要将 Agg 对象添加到与 Bar 相同的图层中。这是 Stat 类型的示例,允许您指定在绘制数据之前如何转换或计算数据。 在此示例中,您传入 "median" 作为其 func 参数,该参数告诉它对每个 Bar 对象使用中值。默认值为“mean”

默认情况下,每个栏都会显示在彼此的顶部。要分离它们,您还需要在图层中添加一个 Dodge 对象。 这是 Move 对象类型的示例,允许您调整不同条形的位置。在本例中,您可以通过传递 gap=0.1 将每个条形之间设置间隙。

最后,使用 .label() 方法指定绘图的标签。通过设置 color="Day",您可以为图例标题指定一个大写字符串。

您的结果图如下所示:

正如您所看到的,每个月的数据由单独的条形簇表示,每个簇中的每个条形代表不同的一天。如果仔细观察,您会发现每个条形也与其他条形稍微分开。

将图分成子图

现在假设您希望每个月图出现在单独的子图上。为此,您可以使用 Plot 对象的 .facet() 方法来决定如何分离数据:

In [3]: import pandas as pd
   ...: import seaborn.objects as so
   ...:
   ...: crossings = pd.read_csv("cycle_crossings_apr_jun.csv")
   ...:
   ...: (
   ...:     so.Plot(
   ...:         data=crossings, x="month",
   ...:         y="max_temp", color="day",
   ...:     )
   ...:     .facet(col="month")
   ...:     .add(
   ...:         so.Bar(),
   ...:         so.Agg(func="median"),
   ...:         so.Dodge(gap=0.1)),
   ...:     .label(
   ...:         x=None, y="Temperature",
   ...:         title="Median Temperature", color="Day",
   ...:     )
   ...:     .show()
   ...: )

这次,当您在 Plot 对象上调用 .facet(col="month") 时,每个月度的数字都会被分离出来:

正如您所看到的,更新后的图现在显示三个子图,每个子图都有不同月份的数据。再次,对代码进行细微调整可以产生显着不同的输出。

使用您到目前为止学到的原理和seaborn文档,您可能想尝试以下练习:

任务 1:使用对象重新绘制您在文章开头创建的 min_temperencemax_Temperature 散点图。另外,请确保每个标记根据其代表的日期具有不同的颜色。 最后,使用星号来代表每个标记。

任务 2:使用显示四座桥梁中每座桥梁的最大和最小桥梁交叉口的对象创建条形图。

任务 3:使用分析早餐谷物卡路里计数的对象创建条形图。卡路里应该放入十个大小相等的箱子中。

任务 1 解决方案 使用对象重新绘制初始散点图的一种方法可能是:

In [1]: import pandas as pd
   ...: import seaborn.objects as so
   ...:
   ...: crossings = pd.read_csv("cycle_crossings_apr_jun.csv")
   ...:
   ...: (
   ...:     so.Plot(
   ...:         data=crossings, x="min_temp",
   ...:         y="max_temp", color="day",
   ...:     )
   ...:     .add(so.Dot(marker="*"))
   ...:     .label(
   ...:         x="Minimum Temperature",
   ...:         y="Maximum Temperature",
   ...:         title="Scatterplot of Temperatures",
   ...:     )
   ...:     .show()
   ...: )

首先,您将数据读入 DataFrame,然后将其连同您感兴趣的数据的列一起传递给 Plot 对象的构造函数。在本例中,您分配 " min_temp""max_temp" 分别为 xy 参数。

您可以通过为每个 xy 值对添加一个 Dot 对象来创建散点图的内容。要使每个点显示为星形,请传入 marker="*"。最后,您使用 .label() 为绘图提供标题以及每个轴的标签。

任务 2 解决方案 创建显示每座桥梁的最大和最小桥梁交叉口的条形图的一种方法是:

In [2]: import pandas as pd
   ...: import seaborn.objects as so
   ...:
   ...: crossings = pd.read_csv("cycle_crossings_apr_jun.csv")
   ...:
   ...: bridge_crossings = crossings.melt(
   ...:     id_vars=["day", "date"],
   ...:     value_vars=[
   ...:         "Brooklyn", "Manhattan",
   ...:         "Williamsburg", "Queensboro",
   ...:     ],
   ...:     var_name="Bridge",
   ...:     value_name="Crossings",
   ...: )
   ...:
   ...: (
   ...:     so.Plot(
   ...:         data=bridge_crossings, x="Bridge",
   ...:         y="Crossings", color="Bridge",
   ...:     )
   ...:     .add(so.Bar(), so.Agg("max"))
   ...:     .add(so.Bar(), so.Agg("min"))
   ...:     .label(
   ...:         x="Bridge", y="Crossings",
   ...:         title="Bridge Crossings",
   ...:     )
   ...:     .show()
   ...: )

您再次使用 .melt() 重构桥数据,然后将其与 "Bridge" 一起传递到 Plot 对象的构造函数中和您感兴趣的 "Crossings" 数据。要构建绘图的内容,您需要添加两对 BarAgg 对象,一个生成最大值条,另一个生成最小值条。最后,使用 .label() 添加一些标题。

任务 3 解决方案 创建分析早餐谷物卡路里计数的条形图的一种方法是:

In [3]: import pandas as pd
   ...: import seaborn.objects as so
   ...:
   ...: cereals_data = pd.read_csv("cereals_data.csv")
   ...:
   ...: (
   ...:     so.Plot(data=cereals_data, x="calories")
   ...:     .add(so.Bar(), so.Hist(bins=10))
   ...:     .label(
   ...:         x="Calories", y="Count",
   ...:         title="Calorie Counts",
   ...:     )
   ...:     .show()
   ...: )

首先,您将数据读入 DataFrame,然后将其与您感兴趣的数据列一起传递到 Plot 对象的构造函数中。在本例中,您设置 x= “卡路里”。条形图的内容是使用 Bar 对象创建的,但您还必须提供 Hist 对象来指定所需的 bin 数量。和以前一样,您添加一些标题并为每个轴添加标签。

尽管您可能不这么认为,但您实际上并没有到达海上旅程的终点,而只是其起点的终点。请记住,seaborn 仍在不断成长,因此总有更多东西需要您学习。本教程的主要重点是了解seaborn的关键原则。您必须理解这些,因为您以后可以以多种方式应用它们来生成非常复杂的绘图。

为什么不重新审视一下您在本教程中完成的各种任务,并使用文档来查看是否可以增强它们?另外,不要忘记,seaborn 的作者免费提供了大量示例数据集,让您可以练习、练习

结论

您现在已经掌握了seaborn 的基础知识。 Seaborn 是一个允许您创建数据统计分析可视化的库。凭借其双 API 和 Matplotlib 基础,它允许您生成各种不同的绘图来满足您的需求。

在本教程中,您学习了:

  • 如何确定可以考虑将seaborn与Python结合使用的情况
  • 如何使用seaborn的函数式接口通过Python来可视化数据
  • 如何使用seaborn的对象接口通过Python来可视化数据
  • 如何使用两个界面创建几种常见的绘图类型
  • 如何通过阅读文档来保持您的技能处于最新状态

有了这些知识,您现在就可以开始在 Python 代码中创建精美的 seaborn 数据可视化,以向其他人展示您分析的数据。

相关文章