seaborn画小提琴图简单,但缺乏自定义能力,画出来的图实在是不好看
matplotlib默认小提琴图不符合论文规范,官网提供了符合规范的图的画法,可调整性好
但是代码实在是有装逼犯的嫌疑,难懂难用难调整
我好不容易调好代码,3天之后就忘了那坨代码具体干了啥,需求一变,重新开始。
。
。
所以我干脆按官网的原理做了个简化版,把奇奇怪怪没必要的 *** 作全换了,调起来舒心很多,还封装了一个函数,方便调用
(代码在文末)
用了循环,不可避免牺牲一点效率,但只要不是同时画几百张图,这点差别感觉不出来
每个步骤干了啥都写在注释里了
顺便用官网的例子画个图,左边用我的函数画的,右边是官网代码画的
看不出差别吧,反正我看不出。
。
。
。
下面是我的函数和官网代码,复制这段代码跑一下,就可以得到上面那张图
import pandas as pd
import numpy as np
import matplotlib as mpl
#允许中文
mpl.rcParams['font.sans-serif']= ['SimHei']
#允许负号
mpl.rcParams['axes.unicode_minus'] = False
import matplotlib.pyplot as plt
###################################这一段是我的函数################################
def cal_box_whisker_scale(value_se):
# 中位,四分位数计算
quartile1,median,quartile3 = np.percentile(value_se,[25, 50, 75], axis=0)
# whisker的最大值计算:“75分位数+1.5*四分位间距”
whiskerMax = quartile3 + (quartile3 - quartile1) * 1.5
# np.clip教程参见 https://blog.csdn.net/qq1483661204/article/details/78150203
# 如果whisker的最大值>数据最大值,则设为最大值,若<75分位,设为75分位
whiskerMax = np.clip(whiskerMax, quartile3, value_se.max())
# whisker的最小值计算:“25分位数-1.5*四分位间距”
whiskerMin = quartile1 - (quartile3 - quartile1) * 1.5
# 如果whisker的最小值<数据最小值,则设为最小值,若>25分位,设为25分位
whiskerMin = np.clip(whiskerMin, value_se.min(), quartile1)
# 返回whisker最小值,25分位,中位,75分位,whisker最大值
return whiskerMin, quartile1, median, quartile3, whiskerMax
'''
参数依次是: 数据dataframe, 用于分组的列的列名,
要画在图上的组的列表(列表顺序即绘图顺序), 绘图值列的列名(纵坐标), 绘图轴
'''
def my_violin_plot(data_df, group_col, group_list, value_col, ax):
for abscissa,g in enumerate(group_list):
# 提取该组的数据, 注意, 这里的g必须是字符, 若是整数, 请去掉'{g}'两端的引号
data_se = data_df.query(f"{group_col}=='{g}'")[value_col]
# 画小提琴,不使用自带的中位数和须
parts = ax.violinplot(
data_se, [abscissa], showmeans=False, showmedians=False,
showextrema=False)
#虽然里面就一个形状, 但是返回的是包含形状对象的字典, 还是得提取出来
pc=parts['bodies'][0]
# 设置形状填充色
pc.set_facecolor('#D43F3A')
# 设置边框色
pc.set_edgecolor('black')
# 设置透明度
pc.set_alpha(1)
# 使用前面定义的函数计算whisker的最小值,25分位,中位,75分位,whisker最大值
whiskerMin,quartile1,medians,quartile3,whiskerMax = cal_box_whisker_scale(data_se)
# 描中位数点
ax.scatter(abscissa, medians, marker='o', color='white', s=30, zorder=3)
# 画箱体
ax.vlines(abscissa, quartile1, quartile3, color='k', linestyle='-', lw=5)
# 画whisker
ax.vlines(abscissa, whiskerMin, whiskerMax, color='k', linestyle='-', lw=1)
##############################下面的是官网代码###################################
def adjacent_values(vals, q1, q3):
upper_adjacent_value = q3 + (q3 - q1) * 1.5
upper_adjacent_value = np.clip(upper_adjacent_value, q3, vals[-1])
lower_adjacent_value = q1 - (q3 - q1) * 1.5
lower_adjacent_value = np.clip(lower_adjacent_value, vals[0], q1)
return lower_adjacent_value, upper_adjacent_value
def set_axis_style(ax, labels):
ax.xaxis.set_tick_params(direction='out')
ax.xaxis.set_ticks_position('bottom')
ax.set_xlabel('Sample name')
# create test data
np.random.seed(19680801)
data = [sorted(np.random.normal(0, std, 100)) for std in range(1, 5)]
dataDf = pd.DataFrame([(str(n+1),d) for n,dat in enumerate(data) for d in dat], columns=['横坐标','纵轴值'])
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 4), sharey=True)
ax1.set_title('我的函数')
ax1.set_ylabel('Observed values')
my_violin_plot(dataDf, '横坐标', [1,2,3,4], '纵轴值', ax1)
ax2.set_title('官网方法')
parts = ax2.violinplot(
data, showmeans=False, showmedians=False,
showextrema=False)
for pc in parts['bodies']:
pc.set_facecolor('#D43F3A')
pc.set_edgecolor('black')
pc.set_alpha(1)
quartile1, medians, quartile3 = np.percentile(data, [25, 50, 75], axis=1)
whiskers = np.array([
adjacent_values(sorted_array, q1, q3)
for sorted_array, q1, q3 in zip(data, quartile1, quartile3)])
whiskers_min, whiskers_max = whiskers[:, 0], whiskers[:, 1]
inds = np.arange(1, len(medians) + 1)
ax2.scatter(inds, medians, marker='o', color='white', s=30, zorder=3)
ax2.vlines(inds, quartile1, quartile3, color='k', linestyle='-', lw=5)
ax2.vlines(inds, whiskers_min, whiskers_max, color='k', linestyle='-', lw=1)
# set style for the axes
labels = ['A', 'B', 'C', 'D']
for ax in [ax1, ax2]:
set_axis_style(ax, labels)
ax1.set_xlim(-0.75, len(labels) - 0.25)
ax1.set_xticks(np.arange(0, len(labels)), labels=labels)
ax2.set_xlim(0.25, len(labels) + 0.75)
ax2.set_xticks(np.arange(1, len(labels) + 1), labels=labels)
plt.subplots_adjust(bottom=0.15, wspace=0.05)
plt.savefig(r'C:\Users\WIN10\Desktop.png', dpi=300)
plt.show()
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)