numpyで2つの三次元座標点列の引き当て
numpyで2つの三次元座標点列の引き当てをできるだけ高速に行いたいのですが、行き詰まっています。
<前提>
numpy
のround()
とbuiltins
のround()
は挙動が異なります。
(これ自体は質問の主旨ではありませんが、これについても解説いただけると大変ありがたいです。)
>>> print(round(411105.886185, 5))
411105.88619
>>> print(np.round(411105.886185, 5))
411105.88618
<本題>
ある3次元座標点列 V
をbuiltins
のround()
を使って小数点以下6桁で記録したa.tsvと
V
から一部を抽出してnumpy
のround()
で小数点以下9桁で記録したb.tsvがあります。
a.tsv、b.tsvともに元々の V
の順番通りに記録されています。
b.tsvの各行が、a.tsvでは何行目になるか、を特定したいです。
a.tsv
977279.482707 734066.643064 662406.439074
627635.945559 451974.042893 929737.099191
1025463.393349 752819.302836 885502.793725
971104.800369 916731.879454 475093.855238
382780.576043 307121.604863 661611.845153
...
b.tsv
977279.482707500 734066.643064236 662406.439073611
971104.800368500 916731.879453945 475093.855237700
382780.576043500 307121.604863451 661611.845152960
...
以下のコードで目的は達成できるのですが、計算量が log(N**2)
となる点と、メモリ効率が悪い点が納得いきません。
import numpy as np
a = np.genfromtxt('a.tsv', delimiter='\t')
b = np.genfromtxt('b.tsv', delimiter='\t')
D = b-a[:, np.newaxis]
S = np.sum(D**2, axis=2)
I = np.argmin(S, axis=0)
print(I)
最初に a
を辞書型に格納すれば O(NlogN)
になると思ったのですが、先述の round
の問題のため上手くいきませんでした。
何卒ご教示お願いいたします。