跳到主要内容
Jason Lee
The owner of this blog site

I'm a programmer who loves to learn new technologies and build cool stuff. I'm currently working as a software engineer at a small company in Shanghai. Hope to find some connections with you!

View All Authors

matplotlib最佳实践

· 阅读需 5 分钟
Jason Lee
The owner of this blog site

前言

跟着油管上的教程学的,作者介绍了一些最佳实践,融合一些自己的思考,把觉得有用的东西记录一下。

part1 折线图

# 导包区
from matplotlib import pyplot as plt
# plt.style.use('seaborn') # 切换图片风格
plt.xkcd() # 手绘风格

# 数据区
ages_x = [25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35]
dev_y = [38496, 42000, 46752, 49320, 53200,
56000, 62316, 64928, 67317, 68748, 73752]
py_dev_y = [45372, 48876, 53850, 57287, 63016,
65998, 70003, 70000, 71496, 75370, 83640]

# 绘图区
plt.plot(ages_x, dev_y, color='k', linestyle='--', label='All Devs')
plt.plot(ages_x, py_dev_y, color='b',marker='o', label='Python Devs')

plt.xlabel('Ages')
plt.ylabel('Median Salary')
plt.title('Median Salary By Age')
plt.legend()

plt.grid(True)

plt.show()

plot.plot()函数,一个对应一条曲线,而且,color, linestyle, marker 单独作为参数写,而不是混在一起;还有,把图例放在 plot 函数里而不是 legend(),都是为了增加可读性。

part2 柱状图

画柱状图的函数是 plt.bar()。但是如果将上述的 plt.plot()改成 plt.bar(),会出现重叠的情况。为了避免重叠,需要添加偏移。

ages_x = [25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35]
dev_y = [38496, 42000, 46752, 49320, 53200,
56000, 62316, 64928, 67317, 68748, 73752]
py_dev_y = [45372, 48876, 53850, 57287, 63016,
65998, 70003, 70000, 71496, 75370, 83640]
js_dev_y = [37810, 43515, 46823, 49293, 53437,
56373, 62375, 66674, 68745, 68746, 74583]

x_indexes = np.arange(len(ages_x))
width=0.25

width 是柱的宽度,柱的数量*width 应该小等于 1。

plt.bar(x_indexes-width, dev_y, color='k', width=width, label='All Devs')
plt.bar(x_indexes, py_dev_y, color='b', width=width, label='Python Devs')
plt.bar(x_indexes+width, js_dev_y, color='g',width=width, label='JavaScript Devs')

紧接着,作者给出了一个实例,从 csv 文件读取数据,并统计各个编程语言的使用人数。用的是 python 标准 csv 库进行读取,后续可以用 pandas。

with open('data.csv') as csv_file:
csv_reader = csv.DictReader(csv_file)
counter = Counter()

for row in csv_reader:
counter.update(row['LanguagesWorkedWith'].split(';'))

print(counter)
with open('data.csv') as csv_file:
csv_reader = csv.DictReader(csv_file)
counter = Counter()

for row in csv_reader:
counter.update(row['LanguagesWorkedWith'].split(';'))

print(counter)

Counter({'JavaScript': 59219, 'HTML/CSS': 55466, 'SQL': 47544, 'Python': 36443, 'Java': 35917, 'Bash/Shell/PowerShell': 31991, 'C#': 27097, 'PHP': 23030, 'C++': 20524, 'TypeScript': 18523, 'C': 18017, 'Other(s):': 7920, 'Ruby': 7331, 'Go': 7201, 'Assembly': 5833, 'Swift': 5744, 'Kotlin': 5620, 'R': 5048, 'VBA': 4781, 'Objective-C': 4191, 'Scala': 3309, 'Rust': 2794, 'Dart': 1683, 'Elixir': 1260, 'Clojure': 1254, 'WebAssembly': 1015, 'F#': 973, 'Erlang': 777})

画图用 bar 函数的话,x 轴就会很拥挤。所以采用横柱状图 plt.barh(),将 x,y 颠倒。

part3 饼状图

饼状图一般适合五种以下的数据展示。

slices = [59219, 55466, 47544, 36443, 35917]
labels = ['JavaScript', 'HTML/CSS', 'SQL', 'Python', 'Java']
explode = [0, 0, 0, 0.1, 0]

plt.pie(slices, labels=labels, explode=explode,startangle=90, autopct='%1.1f%%', wedgeprops={'edgecolor': 'black'})

## wedgeprops: 楔子属性?反正就是设置图表的一些性质

plt.title('My Awesome Pie Chart')
plt.show()

part4 堆积面积图

堆积面积是 stackplot 的直译,但似乎不是很妥当?api 是 plt.stackplot()

part5 折现填充图

在折现的上方或下方进行填充,直观地表示差距。

plt.plot(ages, py_salaries, label='Python')

part6 直方图

bins 是每根柱代表的数据间距, log 表示将数据取对数

plt.hist(ages, bins=bins, edgecolor='black', log=True)

part7 散点图

plt.scatter(views, likes, c=ratio, cmap='summer', edgecolor='black', linewidth=1, alpha=0.75)
cbar = plt.colorbar()
cbar.set_label('Like/Dislike Ratio')

plt.xscale('log')
plt.yscale('log')

plt.title('Median Salary By Age')
plt.xlabel('View Count')
plt.ylabel('Total Likes')
plt.show()

c 是指颜色的深浅度,cmap 设置颜色的系列,plt.colorbar()在图中添加彩色条,plt.xscale('log')将刻度改为对数,防止过大的数远离图表中心。

part8 时序图

data['Date'] = pd.to_datetime(data['Date']) ##把字符串解析为日期格式
data.sort_values('Date', inplace=True)

price_date = data['Date']
price_close = data['Close']

plt.plot_date(price_date, price_close, linestyle='solid')

plt.gcf().autofmt_xdate() ## 自动格式化日期

part9 动画

这是和深度学习相关的一个场景,比如损失函数的显示就是动态的。


import random
from itertools import count
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

plt.style.use('fivethirtyeight')

x_vals = []
y_vals = []

# plt.plot(x_vals, y_vals)


index = count()

def animate(i):
x_vals.append(next(index))
y_vals.append(random.randint(0, 5))

plt.cla()
plt.plot(x_vals, y_vals)

ani = FuncAnimation(plt.gcf(), animate, interval=1000)
plt.tight_layout()
plt.show()

但是上述代码在 jupyter 中无法正常运行,只能用 pycharm 来输出。

part10 子图

plt.style.use('seaborn')

data = pd.read_csv('data.csv')
ages = data['Age']
dev_salaries = data['All_Devs']
py_salaries = data['Python']
js_salaries = data['JavaScript']

fig1, ax1 = plt.subplots()
fig2, ax2 = plt.subplots()

ax1.plot(ages, dev_salaries, color='#444444',
linestyle='--', label='All Devs')

ax2.plot(ages, py_salaries, label='Python')
ax2.plot(ages, js_salaries, label='JavaScript')

ax1.legend()
ax1.set_title('Median Salary (USD) by Age')
ax1.set_ylabel('Median Salary (USD)')

ax2.legend()
ax2.set_xlabel('Ages')
ax2.set_ylabel('Median Salary (USD)')

plt.tight_layout()

plt.show()

fig1.savefig('fig1.png')
fig2.savefig('fig2.png')

首先,子图中最重要的概念是 figure 和 axis。之前的代码都是针对 figure(默认),figure 可以认为是一个容器,表现为窗口;axis 就是一张具体的图。