[PYTHON] Comparison of fitting programs


I am writing an analysis manual and sample programs for an undergraduate lab experiment. For that, I looked into how to fit data by the least squares method with several different programs, and I am leaving my notes here.

Every program uses the same least squares principle, so they ought to give the same result. Unfortunately, I had known for a while that the answer differs depending on the program. Recently I found a text that explains the cause and the remedy, so I decided to organize it for myself. The text is Peter Young, "Everything you wanted to know about Data Analysis and Fitting but were afraid to ask", which appears to be university lecture material. It says that in Gnuplot and in Python's `scipy.optimize.curve_fit`, when the data points have errors, the errors reported for the fitted parameters are incorrect and must be corrected.

For Gnuplot, the text describes the correction as follows:

to get correct error bars on fit parameters from gnuplot when there are error bars on the points, you have to divide gnuplot’s asymptotic standard errors by the square root of the chi-squared per degree of freedom (which gnuplot calls FIT_STDFIT and, fortunately, computes correctly).
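Written as formulas (my own notation: N data points (x_i, y_i) with errors e_i, a fit function f with n_par parameters), FIT_STDFIT is the square root of chi-squared per degree of freedom, and the correction divides each asymptotic error by it:

```latex
\chi^2 = \sum_{i=1}^{N} \left( \frac{y_i - f(x_i)}{e_i} \right)^2 ,
\qquad \nu = N - n_{\mathrm{par}} ,
\qquad \mathrm{FIT\_STDFIT} = \sqrt{\chi^2 / \nu}

\sigma_{p}^{\mathrm{corrected}}
  = \frac{\sigma_{p}^{\mathrm{asymptotic}}}{\mathrm{FIT\_STDFIT}}
  = \sigma_{p}^{\mathrm{asymptotic}} \sqrt{\frac{\nu}{\chi^2}}
```

For the straight-line fit to the 20 points below, n_par = 2, so ν = 18.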

For Python's `scipy.optimize`, the situation is more complicated: `curve_fit` needs the correction, but `leastsq` does not.

I recently learned that error bars on fit parameters given by the routine curve_fit of python also have to be corrected in the same way. This is shown in two of the python scripts in appendix H. Curiously, a different python fitting routine, leastsq, gives the error bars correctly.
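Where the `curve_fit` discrepancy comes from can be checked directly. The sketch below uses made-up synthetic data (an assumption; any data with errors would do) and relies on `curve_fit`'s `absolute_sigma` parameter. With the default `absolute_sigma=False`, `sigma` is treated only as relative weights and the returned covariance matrix is multiplied by chi2/ndf; with `absolute_sigma=True`, the unscaled matrix is returned. Dividing the default errors by sqrt(chi2/ndf) should therefore reproduce the `absolute_sigma=True` errors.

```python
import numpy as np
from scipy.optimize import curve_fit


def line(x, a, b):
    """Straight line a + b*x."""
    return a + b * x


# Synthetic data with per-point errors (illustration only)
rng = np.random.default_rng(0)
x = np.arange(20.0)
ey = rng.uniform(0.8, 1.8, size=x.size)
y = -1.0 + 0.5 * x + rng.normal(0.0, ey)

# Default: sigma is treated as relative weights, covariance is rescaled
par, cov_scaled = curve_fit(line, x, y, sigma=ey)
# absolute_sigma=True: sigma is treated as the true errors, no rescaling
par_abs, cov_abs = curve_fit(line, x, y, sigma=ey, absolute_sigma=True)

chi2 = np.sum(((line(x, *par) - y) / ey) ** 2)
ndf = x.size - len(par)

# The default covariance equals the absolute one times chi2/ndf ...
print(np.allclose(cov_scaled, cov_abs * chi2 / ndf))
# ... so undoing that factor recovers the absolute errors
err_corrected = np.sqrt(np.diag(cov_scaled) / (chi2 / ndf))
print(np.allclose(err_corrected, np.sqrt(np.diag(cov_abs))))
```

In practice, passing `absolute_sigma=True` is a one-argument alternative to correcting the errors by hand.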

The text also remarks that nobody seems to be aware of this:

It is curious that I found no hits on this topic when Googling the internet.

Sample data

I prepared the following data set of 20 points (x, y, ey), where ey is the error on y, and fit it with a straight line. Save it as a text file `data.txt`.

0.0e+00 -4.569987720595017344e-01 1.526828747143463172e+00
1.0e+00 -7.106162255269843353e-01 1.402885069270964458e+00
2.0e+00 1.105159634902675325e+00 1.735638554786020915e+00
3.0e+00 -1.939878950652441869e+00 1.011014634823069747e+00
4.0e+00 3.609690931525689983e+00 1.139915698020605550e+00
5.0e+00 8.535035219721383015e-01 9.338187791237286817e-01
6.0e+00 4.770810591544029755e+00 1.321364026236713451e+00
7.0e+00 3.323982457761388787e+00 1.703973901689593173e+00
8.0e+00 3.100622722027332578e+00 1.002313080286136637e+00
9.0e+00 4.527766245564444070e+00 9.876090792441625243e-01
1.0e+01 1.990062497396323682e+00 1.355607177365929505e+00
1.1e+01 5.113013340421659336e+00 9.283045349565146598e-01
1.2e+01 4.391676777018354905e+00 1.337677147217683160e+00
1.3e+01 5.388022504497612886e+00 9.392443558621643707e-01
1.4e+01 1.134921361159764075e+01 9.232583484294124565e-01
1.5e+01 6.067025020573844074e+00 1.186258237028150475e+00
1.6e+01 1.052771612360148445e+01 1.200732350014090954e+00
1.7e+01 6.221953870216905713e+00 8.454085761899273743e-01
1.8e+01 9.628358150028700990e+00 1.442970173161927772e+00
1.9e+01 9.493784288063746857e+00 8.196526623903285236e-01

Sample code in various programs


First, ROOT (https://root.cern.ch). It needed the shortest code. For simple functions such as a straight line, ROOT provides predefined fit functions, so you don't even have to define one.


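  TGraphErrors *g = new TGraphErrors("data.txt","%lg %lg %lg");
  g->Fit("pol1"); g->Draw("AP");  // fit the predefined straight line "pol1", then draw points with error bars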

The result looks like this.

[Figure: screenshot of the ROOT fit result]


Gnuplot is also good at fitting user-defined functions, so the required code is very short. Its main drawback, in my opinion, is that the default plot looks rather poor, so much of the code below is only cosmetic adjustment.


set fit errorvariables

f(x) = p0 + p1*x
fit f(x) "data.txt" u 1:2:3 via p0,p1

plot "data.txt" u 1:2:3 w yerr, f(x)

print "\n ====== ERROR CORRECTED ========"
print "Chi2/NDF = ",FIT_STDFIT**2 * FIT_NDF,"/",FIT_NDF
print "  p0 = ",p0," +- ",p0_err/FIT_STDFIT
print "  p1 = ",p1," +- ",p1_err/FIT_STDFIT

# ---------The following is the appearance adjustment of the figure--------

set term qt font "Helvetica"
set xrange [-1:20]
set rmargin 5
set tics font ",16"
set key font ",16"
set key left top
set bars fullwidth 0
set style line 1 lc "#0080FF" lw 1 pt 7 ps 1
set style line 2 lc "#FF3333" lw 2 pt 0 ps 1

set label 1 at first 1,11 sprintf("Chi2/ndf = %5.2f / %2d",FIT_STDFIT**2 * FIT_NDF,FIT_NDF) font ",18"
set label 2 at first 1,10 sprintf("p0 = %6.3f +- %7.4f",p0,p0_err/FIT_STDFIT) font ",18"
set label 3 at first 1,9  sprintf("p1 = %6.4f +- %7.5f",p1,p1_err/FIT_STDFIT) font ",18"

plot "data.txt" u 1:2:3 w yerr ls 1,\
     f(x) ls 2

The "Asymptotic Standard Error" printed to standard output is incorrect and needs to be corrected. Specifically, divide each error by the variable FIT_STDFIT, as in the code above. With `set fit errorvariables` at the beginning of the script, the error of each parameter is also stored in a variable named after the parameter with `_err` appended (here `p0_err` and `p1_err`). After this correction, the values agree with ROOT.

Final set of parameters            Asymptotic Standard Error
=======================            ==========================
p0              = -1.06859         +/- 0.9578       (89.64%)
p1              = 0.566268         +/- 0.07983      (14.1%)

correlation matrix of the fit parameters:
                p0     p1     
p0              1.000 
p1             -0.884  1.000 

 ====== ERROR CORRECTED ========
Chi2/NDF = 59.1533703771407/18
  p0 = -1.06858871709936 +- 0.528376469987239
  p1 = 0.566267669300731 +- 0.0440357299923021
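As a quick arithmetic check, the corrected errors can be recomputed from the numbers gnuplot printed above. This is a small sketch; the last digit may differ slightly because the printed asymptotic errors are rounded.

```python
import math

# Values printed by gnuplot above
wssr = 59.1533703771407                   # FIT_STDFIT**2 * FIT_NDF, i.e. chi2
ndf = 18                                  # FIT_NDF = 20 points - 2 parameters
asym_err = {"p0": 0.9578, "p1": 0.07983}  # "Asymptotic Standard Error"

# Same definition as gnuplot's FIT_STDFIT
fit_stdfit = math.sqrt(wssr / ndf)
corrected = {name: err / fit_stdfit for name, err in asym_err.items()}

print("FIT_STDFIT = {:.4f}".format(fit_stdfit))
for name, err in corrected.items():
    print("{} error: {:.4f}".format(name, err))
```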

[Figure: screenshot of the gnuplot fit result]

Keep in mind that this correction applies only when the data points have errors, i.e. when you give three columns after `using` in the `fit` command. Do not apply it when all points are fitted with equal weight (data without errors). In that case the size of the point errors is unknown and is effectively estimated from the scatter of the points around the fit, so the scaled error is the appropriate one.
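The same logic can be seen with `curve_fit` called without `sigma`. In this sketch (synthetic data without per-point errors, an assumption for illustration), the returned covariance should match the textbook ordinary least-squares formula s^2 (X^T X)^-1, where s^2 = RSS/ndf estimates the unknown point error from the scatter. Here the default scaling is exactly what you want, so no correction is applied.

```python
import numpy as np
from scipy.optimize import curve_fit


def line(x, a, b):
    """Straight line a + b*x."""
    return a + b * x


# Synthetic data without per-point errors (illustration only)
rng = np.random.default_rng(1)
x = np.arange(20.0)
y = 2.0 + 0.3 * x + rng.normal(0.0, 1.5, size=x.size)

# No sigma: every point gets the same weight
par, cov = curve_fit(line, x, y)

# Ordinary least squares: estimate the point error from the residuals
X = np.column_stack([np.ones_like(x), x])   # design matrix [1, x]
rss = np.sum((y - X @ par) ** 2)            # residual sum of squares
ndf = x.size - 2
cov_ols = rss / ndf * np.linalg.inv(X.T @ X)

print(np.allclose(cov, cov_ols))
```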

python scipy.optimize.curve_fit

Loading the data was easy with numpy.loadtxt.


#Read data
import numpy as np

data = np.loadtxt("data.txt")
xx = data.T[0]
yy = data.T[1]
ey = data.T[2]

#Define a fitting function ff
def ff(x,a,b):
    return a + b*x

#Fit and view results
from scipy.optimize import curve_fit
import math

par, cov = curve_fit(ff,xx,yy,sigma=ey)

#chi2 and number of degrees of freedom (20 points - 2 parameters = 18)
chi2 = np.sum(((ff(xx,par[0],par[1])-yy)/ey)**2)
ndf = len(xx) - len(par)
print("chi2 = {:7.3f}".format(chi2))
#curve_fit multiplies cov by chi2/ndf by default, so undo that factor
print("p0 : {:10.5f} +- {:10.5f}".format(par[0],math.sqrt(cov[0,0]/chi2*ndf)))
print("p1 : {:10.5f} +- {:10.5f}".format(par[1],math.sqrt(cov[1,1]/chi2*ndf)))

#Display on the graph
import matplotlib.pyplot as plt

x_func = np.arange(0,20,0.1)
y_func = par[0] + par[1]*x_func

plt.errorbar(xx,yy,yerr=ey,fmt="o",label="data")
plt.plot(x_func,y_func,label="fit")
plt.legend()
plt.show()


There seems to be no equivalent of gnuplot's FIT_STDFIT, so I compute chi2 and NDF myself and use them to correct the parameter errors obtained from the diagonal elements of the returned covariance matrix. With this correction, the result agrees with ROOT and with the corrected gnuplot values.

chi2 =  59.153
p0 :   -1.06859 +-    0.52838
p1 :    0.56627 +-    0.04404

[Figure: screenshot of the curve_fit result plotted with matplotlib]

python scipy.optimize.leastsq

I had never used this before, so I referred to https://qiita.com/yamadasuzaku/items/6d42198793651b91a1bc. What was a little confusing is that you have to supply not the function you want to fit but chi, i.e. the residuals divided by the errors (not chi^2).
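To illustrate what leastsq expects, here is a minimal sketch on synthetic data (an assumption; the cross-check with `numpy.linalg.solve` is my own addition). The residual function returns chi, one value per point, and leastsq squares and sums it internally; returning chi^2 would instead make it minimize the sum of fourth powers. The fitted parameters are compared with the exact solution of the weighted normal equations.

```python
import numpy as np
from scipy.optimize import leastsq


def chi(prm, x, y, ey):
    """Residual vector (model - data) / error: this is chi, not chi^2."""
    return ((prm[0] + prm[1] * x) - y) / ey


# Synthetic data with per-point errors (illustration only)
rng = np.random.default_rng(2)
x = np.arange(20.0)
ey = rng.uniform(0.8, 1.8, size=x.size)
y = -1.0 + 0.5 * x + rng.normal(0.0, ey)

# Without full_output, leastsq returns the solution and an integer flag
prm, ier = leastsq(chi, (-0.5, 0.5), args=(x, y, ey))

# The quantity leastsq minimizes: chi2 = sum(chi**2)
chi2 = np.sum(chi(prm, x, y, ey) ** 2)

# Exact weighted least-squares solution for comparison
w = 1.0 / ey**2
A = np.column_stack([np.ones_like(x), x])
p_exact = np.linalg.solve(A.T @ (w[:, None] * A), A.T @ (w * y))
print(prm, p_exact, chi2)
```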


#Read data
import numpy as np

data = np.loadtxt("data.txt")
xx = data.T[0] 
yy = data.T[1]
ey = data.T[2]

#Define Chi
from scipy.optimize import leastsq
import math

def chi(prm,x,y,ey):
    return (((prm[0]+prm[1]*x)-y)/ey)

#Prepare the initial value and fit
init_val = (-0.5, 0.5)

prm, cov, info, msg, ier = leastsq(chi,init_val,args=(xx,yy,ey),full_output=True)

chi2 = np.sum((((prm[0]+prm[1]*xx) - yy)/ey)**2)

print("chi2 = {:7.3f}".format(chi2))
print("p0 : {:10.5f} +- {:10.5f}".format(prm[0],math.sqrt(cov[0,0])))
print("p1 : {:10.5f} +- {:10.5f}".format(prm[1],math.sqrt(cov[1,1])))

The plotting code is the same as above, so it is omitted. The result is as follows. With leastsq no correction is needed: the square roots of the diagonal elements of the returned covariance matrix can be used as they are.
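Why leastsq needs no correction can be seen by comparing covariance matrices. In this sketch (synthetic data, my own function names), the `cov_x` returned by leastsq is the unscaled inverse of J^T J for the chi vector, so it should coincide with the covariance from `curve_fit(..., absolute_sigma=True)`, not with the default `curve_fit` output.

```python
import numpy as np
from scipy.optimize import curve_fit, leastsq


def chi(prm, x, y, ey):
    """Residual vector (model - data) / error for leastsq."""
    return ((prm[0] + prm[1] * x) - y) / ey


def line(x, a, b):
    """Straight line a + b*x for curve_fit."""
    return a + b * x


# Synthetic data with per-point errors (illustration only)
rng = np.random.default_rng(3)
x = np.arange(20.0)
ey = rng.uniform(0.8, 1.8, size=x.size)
y = -1.0 + 0.5 * x + rng.normal(0.0, ey)

# leastsq: cov_x carries no chi2/ndf factor
prm, cov_x, info, msg, ier = leastsq(chi, (-0.5, 0.5), args=(x, y, ey), full_output=True)
# curve_fit reproduces it only when sigma is declared absolute
par_abs, cov_abs = curve_fit(line, x, y, p0=(-0.5, 0.5), sigma=ey, absolute_sigma=True)

print(np.allclose(cov_x, cov_abs, rtol=1e-5))
print(np.sqrt(np.diag(cov_x)))   # parameter errors, usable as they are
```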

chi2 =  59.153
p0 :   -1.06859 +-    0.52838
p1 :    0.56627 +-    0.04404

Closing remarks

I had suspected for a while that gnuplot's results might not match a manual calculation, but since I usually use ROOT, I never checked it seriously. Finally finding out how to deal with it was very refreshing. Personally, ROOT is the easiest because I'm used to it, but for teaching undergraduates I thought Gnuplot or Python's curve_fit would be easier to understand. However, both of them require the errors on the fitted parameters to be corrected.

By the way, I have been thinking that nowadays it may be better to teach undergraduates Python rather than C or Gnuplot, so I also tried Python here partly for my own study. Python is certainly attractive in that it can do everything from data processing to plotting, but for defining functions and drawing graphs I felt that Gnuplot, which specializes in plotting, is more intuitive. As a simple example, the two snippets below display almost the same thing, and I think the Gnuplot version is more intuitive. The Python one looks better, though.


set xrange [0:10]
f(x) = sin(x)
plot f(x)


import numpy as np
import matplotlib.pyplot as plt
x = np.arange(0,10,0.1)
y = np.sin(x)
plt.plot(x,y)
plt.show()
