02 February 2021

(Python) Matplotlibでグラフ作成

PythonでMatplotlibを用いてグラフを作成する基本サンプル

matplotlib.pyplotに直接描画するMATLAB方式で単純なグラフを作成

MATLABで用いられていた流儀、pyplotにplotで直接グラフを描いていく方法。 設定項目が少ないので、あまりお薦めではないが、簡便にグラフを描くには最適

20210202-simple-plot.jpg

#!/usr/bin/python
# -*- coding: utf-8 -*-
 
import matplotlib.pyplot as plt
import numpy as np
 
# x軸の0.0から5.0を100等分
#x = np.linspace(0, 5, 100)
# x軸の0.0から5.0まで0.05間隔で
x = np.arange(0, 5, 0.05)
 
y1 = np.sin(x * 2 * np.pi) * np.exp(-x)
y2 = np.sin(x * 2 * np.pi)
 
plt.plot(x, y1, color='red', label='y=sin(2pi*x)*exp(-x)')
plt.plot(x, y2, color='darkgreen', linewidth=0.5, linestyle='--')
 
plt.grid(color='#bbbbff', linestyle='dashed')
plt.title('Sample Graph')
plt.legend(loc='upper right')
 
plt.show()

numpy.linspace
numpy.arange の2つのメソッドは、どちらも等差数列を作るのだが、要素数の与え方に注意が少しだけ必要。

たとえば、0 から 5 を10等分した等差数列を作る場合は、次のように「赤で着色した」部分に注意。

>>> np.arange(0, 5+0.5, 0.5)
array([ 0. ,  0.5,  1. ,  1.5,  2. ,  2.5,  3. ,  3.5,  4. ,  4.5,  5. ])
 
>>> np.linspace(0, 5, 10+1)
array([ 0. ,  0.5,  1. ,  1.5,  2. ,  2.5,  3. ,  3.5,  4. ,  4.5,  5. ])

matplotlib.pyplot内にaxesオブジェクトを作成し、そこにグラフを描画する

こちらのほうが、軸やグラフの細かな設定ができる。

20210202-axes-plot.jpg

#!/usr/bin/python
# -*- coding: utf-8 -*-
 
import matplotlib.pyplot as plt
import numpy as np
 
# x軸の0.0から5.0を100等分
#x = np.linspace(0, 5, 100)
# x軸の0.0から5.0まで0.05間隔で
x = np.arange(0, 5, 0.05)
 
y1 = np.sin(x * 2 * np.pi) * np.exp(-x)
y2 = np.sin(x * 2 * np.pi)
 
fig, ax = plt.subplots()
 
ax.plot(x, y1, color='red', label='y=sin(2pi*x)*exp(-x)')
ax.plot(x, y2, color='darkgreen', linewidth=0.5, linestyle='--')
 
ax.set_xticks(np.linspace(0,2,5))
ax.set_xticks(np.linspace(0,2,21) ,minor=True)
ax.grid(color='#bbbbff', linestyle='dashed', which="major")
ax.grid(color='#ddddff', linestyle='dotted', which="minor")
ax.set_title('Sample Graph')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.legend()
 
plt.show()

2つの曲線の交点を求める場合

y = x^2 - 1 と、 y = x + 1 の交点は、[x, y] = [-1, 0] と [2, 3] であることを示すグラフを作成するスクリプトは

20210202-crosspoint.jpg

#!/usr/bin/python
# -*- coding: utf-8 -*-
 
import matplotlib.pyplot as plt
import numpy as np
 
x1 = np.arange(-2.5, 2.5+0.1, 0.1)
y1 = x1**2 - 1.0
y2 = x1 + 1
 
# 交点の数列index No.を求める
crosspoint_x = np.argwhere(np.sign(np.round(y1 - y2, 3)) == 0)
 
fig, ax = plt.subplots()
 
ax.plot(x1, y1, color='red', linewidth='1', label='y = x^2 - 1')
ax.plot(x1, y2, color='blue', linewidth='1', label='y = x + 1')
ax.plot(x1[crosspoint_x], y1[crosspoint_x], 'o',  color='black')
 
ax.set_xticks(np.arange(-2,2+1,1))
ax.set_xticks(np.arange(-2.5,2.5+0.5,0.5) ,minor=True)
ax.grid(color='#bbbbff', linestyle='dashed', which="major")
ax.grid(color='#ddddff', linestyle='dotted', which="minor")
ax.set_title('Sample Graph')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.legend()
 
plt.show()

交点は、双方の数列の差分を計算して、その符号(numpy.sign)を計算し、それがゼロとなるリスト番号を抽出する(numpy.argwhere)処理をする。 求められたリスト番号で指し示される[x,y]が、2つの曲線の差分がゼロとなる値(つまり、交点)を見つけ出すという処理をしている。

>>> x1 = np.arange(-2.5, 2.5+0.5, 0.5)
>>> x1
array([-2.5, -2. , -1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. ,  2.5])
 
>>> y1 = x1**2 - 1.0
>>> y1
array([ 5.25,  3.  ,  1.25,  0.  , -0.75, -1.  , -0.75,  0.  ,  1.25,
        3.  ,  5.25])
 
>>> y2 = x1 + 1
>>> y2
array([-1.5, -1. , -0.5,  0. ,  0.5,  1. ,  1.5,  2. ,  2.5,  3. ,  3.5])
 
>>> np.sign(np.round(y1 - y2, 3))
array([ 1.,  1.,  1.,  0., -1., -1., -1., -1., -1.,  0.,  1.])
 
>>> crosspoint_x = np.argwhere(np.sign(np.round(y1 - y2, 3)) == 0)
>>> crosspoint_x
array([[3],
       [9]])

ソースコードを参考にしたwebサイトは、

crosspoint_x = np.argwhere(np.sign(y1 - y2) == 0)

というコードを用いていたが、浮動小数点の誤差範囲があり、完全にゼロに鳴らない場合もあるので、

crosspoint_x = np.argwhere(np.sign(np.round(y1 - y2, 3)) == 0)

というふうに改変している。

SciPiで補間曲線を描く

20210202-interpolate.jpg

scipy.interpolate.interp1d を使って、補間曲線のデータを作成する。

kind パラメータに指定できるのは、linear = 線形補間、sliner = 一次スプライン補間、quadratic = 二次スプライン補間、cubic = 三次スプライン補間

#!/usr/bin/python
# -*- coding: utf-8 -*-
 
import matplotlib.pyplot as plt
import numpy as np
from scipy import interpolate
 
# 粗い sin 曲線(x軸20分割のデータを作成)
x = np.linspace(0, 6, 20)
y = np.sin(x*3.14)
 
# 補完曲線(データ補完して、x軸100分割のデータとする)
xn = np.linspace(0, 6, 100)
f = interpolate.interp1d(x, y, kind='cubic')
yn = f(xn)
 
fig, ax = plt.subplots()
# 粗い曲線(データがある部分の「点」のみ表示)
ax.plot(x, y, 'o', label='original data')
# 補完曲線
ax.plot(xn, yn, color='red', label='interpolate')
ax.grid()
ax.legend(loc='upper right')
 
plt.show()

日本語フォントの利用とCSVデータからの散布図の作成

20210202-scatter-plot.jpg

Python 3 と Python 2 の双方で実行可能な、日本語文字列を表示するスクリプトを作成してみた

#!/usr/bin/python
# -*- coding: utf-8 -*-
 
import matplotlib.pyplot as plt
import numpy as np
from scipy import interpolate
 
### 日本語フォントの設定
import matplotlib.font_manager as fm
from matplotlib import rcParams
fm._rebuild()
font_path = "/usr/share/fonts/truetype/mplus-TESTFLIGHT-063a/mplus-1m-medium.ttf"
font_prop = fm.FontProperties(fname=font_path)
plt.rcParams["font.family"] = font_prop.get_name()
### 日本語フォントの設定(ここまで)
 
# CSVファイルから、データを読み込む
data = np.genfromtxt('sanpuzu.csv', skip_header=1, delimiter=',', dtype=['U20', 'float', 'int'], usecols=[1,2,3], converters={1: lambda x: x.decode('utf_8')})
 
fig, ax = plt.subplots()
 
# 「点」のプロット
ax.plot(data['f1'], data['f2'], 'o')
# それぞれの点に対する「ラベル」
for i in range(data.size):
    ax.annotate(data[i][0], xy=(data[i][1], data[i][2]), fontsize='small', color='maroon')
 
ax.grid()
ax.set_title(u'県庁所在地ごとの物価と平均給与の関係')
ax.set_xlabel(u'平均消費者物価地域差指数')
ax.set_ylabel(u'現金給与総額')
 
plt.show()

読み込むCSVデータは次のようなもの

都道府県,県庁所在地,平均消費者物価地域差指数,現金給与総額
北海道,札幌市,99.6,275093
青森,青森市,98.6,230307
岩手,盛岡市,99.4,252656
宮城,仙台市,99.2,278847
秋田,秋田市,98.2,243014
山形,山形市,99.4,253131
福島,福島市,100.3,277150
茨城,水戸市,98.6,301913
栃木,宇都宮市,99.2,291574
群馬,前橋市,96.4,273227
埼玉,さいたま市,102.8,269827
千葉,千葉市,101.1,281349
東京,東京都区部,105.1,379858
神奈川,横浜市,105.1,314231
新潟,新潟市,98.9,258607
富山,富山市,99.5,273676
石川,金沢市,100.3,271260
福井,福井市,99.3,279745
山梨,甲府市,99.4,264987
長野,長野市,97.5,276071
岐阜,岐阜市,98.1,272856
静岡,静岡市,99.2,299973
愛知,名古屋市,98.9,312656
三重,津市,98.2,299360
滋賀,大津市,100.4,295613
京都,京都市,100.9,278542
大阪,大阪市,99.9,304025
兵庫,神戸市,101.2,291960
奈良,奈良市,96.7,257034
和歌山,和歌山市,99.8,272216
鳥取,鳥取市,98.3,245547
島根,松江市,99.8,259231
岡山,岡山市,98.5,290686
広島,広島市,98.9,290240
山口,山口市,98.5,264835
徳島,徳島市,100.2,278432
香川,高松市,98.9,265377
愛媛,松山市,98,257662
高知,高知市,99.2,290196
福岡,福岡市,97,288308
佐賀,佐賀市,96.9,251901
長崎,長崎市,101.2,259043
熊本,熊本市,98.4,261999
大分,大分市,98,266460
宮崎,宮崎市,96.8,250215
鹿児島,鹿児島市,97.2,250803
沖縄,那覇市,99.2,244571

numpy.genfromtxt を使ってCSVテキストファイルを読み込むスクリプトは次のようなもので

data = np.genfromtxt('sanpuzu.csv', skip_header=1, delimiter=',', dtype=['U20', 'float', 'int'], usecols=[1,2,3], converters={1: lambda x: x.decode('utf_8')})

CSVの幾つめのカラムを読み込むかは usecols=[1,2,3] のように指定し、カラムは0(ゼロ)から始まり、 0, 1, 2, 3 ... と数えられる。 また、読み取ったデータは型変換されて格納されるが、その型は dtype=['U20', 'float', 'int'] のように指定する。

dtypeには numpy.dtype の名前または Python の型名を指定する。詳しくは『NumPyのデータ型dtype一覧とastypeによる変換(キャスト)』、『NumPy : Data type objects (dtype)』などが参考になる。

よく使う例として、こういうものがある。

整数(8ビット)int8i1
整数(16ビット)int16i2
整数(32ビット)int32i4
符号なし整数(8ビット)uint8u1
浮動小数点(16ビット)float16f2
浮動小数点(32ビット)float32f4
文字列(Unicode)unicodeU