# [PYTHON] [TF] How to specify variables to update with Optimizer

When building and training a Deep Learning model, you sometimes want to control which parameters are learned. For example, when training two networks alternately, you do not want to update the parameters of the other network.

There are two main ways to specify which parameters to learn.

1. Set the `trainable` argument of the `Variable` to `False`

```python
x = tf.Variable(tf.constant([2.]), name='x', trainable=False)
```

2. Pass the list of variables to update to the Optimizer via `var_list`

```python
opt = tf.train.GradientDescentOptimizer(1.0)
train = opt.minimize(f, var_list=[w, b])
```

Method 2 makes the code easier to change later.
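As a sketch of the alternating-training case from the introduction (all names here are hypothetical, not from the example below), you can build one train op per network and run them in turn:

```python
import tensorflow as tf

# Two toy "networks" with one parameter each (hypothetical example)
wa = tf.Variable(tf.constant([1.]), name='wa')
wb = tf.Variable(tf.constant([2.]), name='wb')

loss = tf.square(wa * wb - 1.)  # toy joint loss

opt = tf.train.GradientDescentOptimizer(0.1)
train_a = opt.minimize(loss, var_list=[wa])  # updates only wa
train_b = opt.minimize(loss, var_list=[wb])  # updates only wb

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    for _ in range(10):   # alternate the two updates
        sess.run(train_a)
        sess.run(train_b)
```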

## When you do nothing

When the code below is executed, all of the variables w, b, x, and y_ are updated. (The code is not practical; normally x and y_ would be placeholders, so this would not matter, but please take it as an illustrative example.)
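For reference, a minimal sketch (reusing the names from the example below) of the more practical variant, where x and y_ are placeholders fed at run time instead of variables:

```python
# Placeholders are supplied via feed_dict and are never updated by the optimizer
x = tf.placeholder(tf.float32, shape=[1], name='x')
y_ = tf.placeholder(tf.float32, shape=[1], name='y_')
# later: sess.run(train, feed_dict={x: [2.], y_: [5.]})
```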

```python
import tensorflow as tf
import numpy as np

w = tf.Variable(tf.constant([3.]), name='w')
b = tf.Variable(tf.constant([1.]), name='b')
x = tf.Variable(tf.constant([2.]), name='x')
y_ = tf.Variable(tf.constant([5.]), name='y_')

p = w*x
y = p+b
s = -y
t = s +y_
f = t*t

gx, gb, gw, gp, gy, gy_,gs, gt, gf = tf.gradients(f, [x, b, w, p, y, y_,s, t, f])

init = tf.initialize_all_variables()

opt = tf.train.GradientDescentOptimizer(1.0)
train = opt.minimize(f)

with tf.Session() as sess:
    sess.run(init)
    print 'x:%.2f, w:%.2f, b:%.2f' % (sess.run(x), sess.run(w), sess.run(b))
    print 'p:%.2f, y:%.2f, y_:%.2f'% (sess.run(p), sess.run(y), sess.run(y_))
    print 's:%.2f, t:%.2f, f:%.2f' % (sess.run(s), sess.run(t), sess.run(f))

    print '---------- gradient ----------'
    print 'gx:%.2f, gw:%.2f, gb: %.2f' % (sess.run(gx), sess.run(gw), sess.run(gb))
    print 'gp:%.2f, gy:%.2f, gy_:%.2f' %(sess.run(gp), sess.run(gy), sess.run(gy_))
    print 'gs:%.2f, gt:%.2f, gf:%.2f' %(sess.run(gs), sess.run(gt), sess.run(gf))
    print '---------- run GradientDescentOptimizer ----------'
    sess.run(train)

    print 'x:%.2f, w:%.2f, b:%.2f' % (sess.run(x), sess.run(w), sess.run(b))
    print 'p:%.2f, y:%.2f, y_:%.2f'% (sess.run(p), sess.run(y), sess.run(y_))
    print 's:%.2f, t:%.2f, f:%.2f'%(sess.run(s), sess.run(t), sess.run(f))

    print '---------- gradient ----------'
    print 'gx:%.2f, gw:%.2f, gb: %.2f' % (sess.run(gx), sess.run(gw), sess.run(gb))
    print 'gp:%.2f, gy:%.2f, gy_:%.2f' %(sess.run(gp), sess.run(gy), sess.run(gy_))
    print 'gs:%.2f, gt:%.2f, gf:%.2f' %(sess.run(gs), sess.run(gt), sess.run(gf))

```

Execution result

After `sess.run(train)`, you can see that the values of x and y_ have also been updated.



x:2.00, w:3.00, b:1.00
p:6.00, y:7.00, y_:5.00
s:-7.00, t:-2.00, f:4.00
---------- gradient ----------
gx:12.00, gw:8.00, gb: 4.00
gp:4.00, gy:4.00, gy_:-4.00
gs:-4.00, gt:-4.00, gf:1.00
---------- run GradientDescentOptimizer ----------
x:-10.00, w:-5.00, b:-3.00
p:50.00, y:47.00, y_:9.00
s:-47.00, t:-38.00, f:1444.00
---------- gradient ----------
gx:-380.00, gw:-760.00, gb: 76.00
gp:76.00, gy:76.00, gy_:-76.00
gs:-76.00, gt:-76.00, gf:1.00
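As a sanity check on the update rule (gradient descent with learning rate 1.0 subtracts the gradient from each trained variable): x: 2 - 12 = -10, w: 3 - 8 = -5, b: 1 - 4 = -3, and y_: 5 - (-4) = 9, matching the printed values.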


## When trainable=False is specified


```python
import tensorflow as tf
import numpy as np

w = tf.Variable(tf.constant([3.]), name='w')
b = tf.Variable(tf.constant([1.]), name='b')
x = tf.Variable(tf.constant([2.]), name='x', trainable=False)
y_ = tf.Variable(tf.constant([5.]), name='y_', trainable=False)

p = w*x
y = p+b
s = -y
t = s +y_
f = t*t

gx, gb, gw, gp, gy, gy_,gs, gt, gf = tf.gradients(f, [x, b, w, p, y, y_,s, t, f])

init = tf.initialize_all_variables()

opt = tf.train.GradientDescentOptimizer(1.0)
train = opt.minimize(f)

with tf.Session() as sess:
    sess.run(init)
    print 'x:%.2f, w:%.2f, b:%.2f' % (sess.run(x), sess.run(w), sess.run(b))
    print 'p:%.2f, y:%.2f, y_:%.2f'% (sess.run(p), sess.run(y), sess.run(y_))
    print 's:%.2f, t:%.2f, f:%.2f' % (sess.run(s), sess.run(t), sess.run(f))

    print '---------- gradient ----------'
    print 'gx:%.2f, gw:%.2f, gb: %.2f' % (sess.run(gx), sess.run(gw), sess.run(gb))
    print 'gp:%.2f, gy:%.2f, gy_:%.2f' %(sess.run(gp), sess.run(gy), sess.run(gy_))
    print 'gs:%.2f, gt:%.2f, gf:%.2f' %(sess.run(gs), sess.run(gt), sess.run(gf))
    print '---------- run GradientDescentOptimizer ----------'
    sess.run(train)

    print 'x:%.2f, w:%.2f, b:%.2f' % (sess.run(x), sess.run(w), sess.run(b))
    print 'p:%.2f, y:%.2f, y_:%.2f'% (sess.run(p), sess.run(y), sess.run(y_))
    print 's:%.2f, t:%.2f, f:%.2f'%(sess.run(s), sess.run(t), sess.run(f))

    print '---------- gradient ----------'
    print 'gx:%.2f, gw:%.2f, gb: %.2f' % (sess.run(gx), sess.run(gw), sess.run(gb))
    print 'gp:%.2f, gy:%.2f, gy_:%.2f' %(sess.run(gp), sess.run(gy), sess.run(gy_))
    print 'gs:%.2f, gt:%.2f, gf:%.2f' %(sess.run(gs), sess.run(gt), sess.run(gf))

```

Execution result

Even after `sess.run(train)`, you can see that the values of x and y_ are unchanged.



x:2.00, w:3.00, b:1.00
p:6.00, y:7.00, y_:5.00
s:-7.00, t:-2.00, f:4.00
---------- gradient ----------
gx:12.00, gw:8.00, gb: 4.00
gp:4.00, gy:4.00, gy_:-4.00
gs:-4.00, gt:-4.00, gf:1.00
---------- run GradientDescentOptimizer ----------
x:2.00, w:-5.00, b:-3.00
p:-10.00, y:-13.00, y_:5.00
s:13.00, t:18.00, f:324.00
---------- gradient ----------
gx:180.00, gw:-72.00, gb: -36.00
gp:-36.00, gy:-36.00, gy_:36.00
gs:36.00, gt:36.00, gf:1.00
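This works because, when var_list is not given, `minimize` defaults to the variables in the graph's `TRAINABLE_VARIABLES` collection, and `trainable=False` keeps a variable out of that collection. A quick check, run against the same graph as above:

```python
# Only w and b were created with trainable=True (the default),
# so they are all the optimizer sees when var_list is omitted.
for v in tf.trainable_variables():
    print v.name
```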


## When passing a list of variables to update to Optimizer

To restrict the update to specific variables, pass them to `minimize` via the `var_list` argument.


```python
import tensorflow as tf
import numpy as np

w = tf.Variable(tf.constant([3.]), name='w')
b = tf.Variable(tf.constant([1.]), name='b')
x = tf.Variable(tf.constant([2.]), name='x')
y_ = tf.Variable(tf.constant([5.]), name='y_')

p = w*x
y = p+b
s = -y
t = s +y_
f = t*t


gx, gb, gw, gp, gy, gy_,gs, gt, gf = tf.gradients(f, [x, b, w, p, y, y_,s, t, f])

init = tf.initialize_all_variables()

opt = tf.train.GradientDescentOptimizer(1.0)
train = opt.minimize(f, var_list=[w,b])

with tf.Session() as sess:
    sess.run(init)
    print 'x:%.2f, w:%.2f, b:%.2f' % (sess.run(x), sess.run(w), sess.run(b))
    print 'p:%.2f, y:%.2f, y_:%.2f'% (sess.run(p), sess.run(y), sess.run(y_))
    print 's:%.2f, t:%.2f, f:%.2f' % (sess.run(s), sess.run(t), sess.run(f))

    print '---------- gradient ----------'
    print 'gx:%.2f, gw:%.2f, gb: %.2f' % (sess.run(gx), sess.run(gw), sess.run(gb))
    print 'gp:%.2f, gy:%.2f, gy_:%.2f' %(sess.run(gp), sess.run(gy), sess.run(gy_))
    print 'gs:%.2f, gt:%.2f, gf:%.2f' %(sess.run(gs), sess.run(gt), sess.run(gf))
    print '---------- run GradientDescentOptimizer ----------'
    sess.run(train)

    print 'x:%.2f, w:%.2f, b:%.2f' % (sess.run(x), sess.run(w), sess.run(b))
    print 'p:%.2f, y:%.2f, y_:%.2f'% (sess.run(p), sess.run(y), sess.run(y_))
    print 's:%.2f, t:%.2f, f:%.2f'%(sess.run(s), sess.run(t), sess.run(f))

    print '---------- gradient ----------'
    print 'gx:%.2f, gw:%.2f, gb: %.2f' % (sess.run(gx), sess.run(gw), sess.run(gb))
    print 'gp:%.2f, gy:%.2f, gy_:%.2f' %(sess.run(gp), sess.run(gy), sess.run(gy_))
    print 'gs:%.2f, gt:%.2f, gf:%.2f' %(sess.run(gs), sess.run(gt), sess.run(gf))

```

Execution result

This time only w and b are updated; x and y_ keep their initial values, just as in the trainable=False case.

x:2.00, w:3.00, b:1.00
p:6.00, y:7.00, y_:5.00
s:-7.00, t:-2.00, f:4.00
---------- gradient ----------
gx:12.00, gw:8.00, gb: 4.00
gp:4.00, gy:4.00, gy_:-4.00
gs:-4.00, gt:-4.00, gf:1.00
---------- run GradientDescentOptimizer ----------
x:2.00, w:-5.00, b:-3.00
p:-10.00, y:-13.00, y_:5.00
s:13.00, t:18.00, f:324.00
---------- gradient ----------
gx:180.00, gw:-72.00, gb: -36.00
gp:-36.00, gy:-36.00, gy_:36.00
gs:36.00, gt:36.00, gf:1.00
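As an aside (not part of the original example), the lower-level `compute_gradients`/`apply_gradients` pair that `minimize` wraps also accepts `var_list`; this is useful if you want to inspect or transform the gradients before applying them. A minimal sketch:

```python
opt = tf.train.GradientDescentOptimizer(1.0)
# Returns (gradient, variable) pairs, restricted to var_list
grads_and_vars = opt.compute_gradients(f, var_list=[w, b])
train = opt.apply_gradients(grads_and_vars)
```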

If listing the variables for var_list by hand is a bother, using a scope makes it a little easier. The procedure is as follows.

  1. Declare the variables inside a scope.
  2. Pass the scope name to `tf.get_collection` to get the list of variables in that scope.

```python
import tensorflow as tf
import numpy as np

with tf.variable_scope("params"):
    w = tf.Variable(tf.constant([3.]), name='w')
    b = tf.Variable(tf.constant([1.]), name='b')

with tf.variable_scope("input"):
    x = tf.Variable(tf.constant([2.]), name='x')
    y_ = tf.Variable(tf.constant([5.]), name='y_')
    
with tf.variable_scope("intermediate"):
    p = w*x
    y = p+b
    s = -y
    t = s +y_
    f = t*t    


gx, gb, gw, gp, gy, gy_,gs, gt, gf = tf.gradients(f, [x, b, w, p, y, y_,s, t, f])

train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="params")
print 'train_vars'
for v in train_vars:
    print v.name

init = tf.initialize_all_variables()

opt = tf.train.GradientDescentOptimizer(1.0)
train = opt.minimize(f, var_list=train_vars)

with tf.Session() as sess:
    sess.run(init)
    print 'x:%.2f, w:%.2f, b:%.2f' % (sess.run(x), sess.run(w), sess.run(b))
    print 'p:%.2f, y:%.2f, y_:%.2f'% (sess.run(p), sess.run(y), sess.run(y_))
    print 's:%.2f, t:%.2f, f:%.2f' % (sess.run(s), sess.run(t), sess.run(f))

    print '---------- gradient ----------'
    print 'gx:%.2f, gw:%.2f, gb: %.2f' % (sess.run(gx), sess.run(gw), sess.run(gb))
    print 'gp:%.2f, gy:%.2f, gy_:%.2f' %(sess.run(gp), sess.run(gy), sess.run(gy_))
    print 'gs:%.2f, gt:%.2f, gf:%.2f' %(sess.run(gs), sess.run(gt), sess.run(gf))
    print '---------- run GradientDescentOptimizer ----------'
    sess.run(train)

    print 'x:%.2f, w:%.2f, b:%.2f' % (sess.run(x), sess.run(w), sess.run(b))
    print 'p:%.2f, y:%.2f, y_:%.2f'% (sess.run(p), sess.run(y), sess.run(y_))
    print 's:%.2f, t:%.2f, f:%.2f'%(sess.run(s), sess.run(t), sess.run(f))

    print '---------- gradient ----------'
    print 'gx:%.2f, gw:%.2f, gb: %.2f' % (sess.run(gx), sess.run(gw), sess.run(gb))
    print 'gp:%.2f, gy:%.2f, gy_:%.2f' %(sess.run(gp), sess.run(gy), sess.run(gy_))
    print 'gs:%.2f, gt:%.2f, gf:%.2f' %(sess.run(gs), sess.run(gt), sess.run(gf))

```

Execution result

train_vars
params/w:0
params/b:0
x:2.00, w:3.00, b:1.00
p:6.00, y:7.00, y_:5.00
s:-7.00, t:-2.00, f:4.00
---------- gradient ----------
gx:12.00, gw:8.00, gb: 4.00
gp:4.00, gy:4.00, gy_:-4.00
gs:-4.00, gt:-4.00, gf:1.00
---------- run GradientDescentOptimizer ----------
x:2.00, w:-5.00, b:-3.00
p:-10.00, y:-13.00, y_:5.00
s:13.00, t:18.00, f:324.00
---------- gradient ----------
gx:180.00, gw:-72.00, gb: -36.00
gp:-36.00, gy:-36.00, gy_:36.00
gs:36.00, gt:36.00, gf:1.00
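One caveat: the scope argument of `tf.get_collection` is matched against variable names with `re.match`, so scope="params" acts as a prefix filter (it would also match a scope named params2, for example). Choose scope names that are not prefixes of one another.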
