首页 > 技术文章 > 3、会话tf.Session()

pengzhonglian 2019-11-11 22:13 原文


 ①tf.Session()

运行TensorFlow操作图的类,使用默认注册的图(可以指定运行图)

 1 import os
 2 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #去掉警告,将警告级别提升
 3 
 4 # 创建一张图
 5 g = tf.Graph()
 6 
 7 with g.as_default(): #作为默认图
 8     c = tf.constant(11)
 9     print(c.graph)
10 
11 a = tf.constant(2)   #定义一个常量
12 b = tf.constant(4)
13 sum = tf.add(a,b)    #加法操作
14 gr = tf.get_default_graph()
15 
16 #一个会话只能使用一张图,默认是注册图,即图gr
17 # with tf.Session() as sess:  #上下文管理
18 #     print(sess.run(sum))   #run运行加法op
19 #     print(sess.run(c))  # (Tensor Tensor("Const:0", shape=(), dtype=int32) is not an element of this graph.)
20 
21 #在会话中指定图运行
22 with tf.Session(graph=g) as sess:  #上下文管理
23     print(sess.run(c))

输出:

<tensorflow.python.framework.ops.Graph object at 0x000002486656CD30>

11

 

②会话资源

会话拥有很多资源,如tf.Variable,tf.QueueBase和tf.ReaderBase等,会话结束后需要进行资源释放

  • 使用方法 

   1、sess = tf.Session(),sess.run(),sess.close()

   2、使用上下文管理器

      with tf.Session() as sess:

        sess.run()

 

tensorflow可以分为前端系统(定义程序的图的结构)和后端系统(运算图的结构,用cpu,gpu进行运算)

会话的功能:

  • 运算图的结构
  • 分配资源计算
  • 掌握资源(变量的资源,队列,线程),所以一旦会话结束,所有的资源将不能再使用

 

③打印设备信息

config=tf.ConfigProto(log_device_placement=True)
with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:  #上下文管理
    # print(sess.run(sum))   #run运行加法op
    print("a.graph:",a.graph)

输出:

Device mapping:
/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: GeForce GTX 1650, pci bus id: 0000:01:00.0, compute capability: 7.5

 

④交互式操作  tf.InteractiveSession()

在命令行进行测试时使用

 

只要有会话的上下文环境,就可以使用操作方便的eval()

1 a = tf.constant(2)   #定义一个常量
2 b = tf.constant(4)
3 sum1 = tf.add(a,b)    #加法操作
4 
5 with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:  #上下文管理
6     print(sess.run(sum1))   #run运行加法op
7     print(sum1.eval())

输出:

Device mapping:
/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: GeForce GTX 1650, pci bus id: 0000:01:00.0, compute capability: 7.5
Add: (Add): /job:localhost/replica:0/task:0/device:GPU:0
Const: (Const): /job:localhost/replica:0/task:0/device:GPU:0
Const_1: (Const): /job:localhost/replica:0/task:0/device:GPU:0
6
6

 

⑤会话中的run方法

  run(fetches, feed_ dict=None,graph=None)    运行ops和计算tensor

  • 嵌套列表,元组。  namedtuple,dict或OrderedDict( 重载的运算符也能运行)
  • feed_dict允许调用者覆盖图中指定张量的值,提供给placeholder使用
  • 返回值异常

    RuntimeError:如果它Session处于无效状态(例如已关闭)。
    TypeError:如果fetches或feed_ dict键是不 合适的类型。
    ValueError:如果fetches或feed_ dict键 无效或引用Tensor不存在。

 (1)运行多个op

import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #去掉警告,将警告级别提升

a = tf.constant(2)   #定义一个常量
b = tf.constant(4)
sum1 = tf.add(a,b)    #加法操作

with tf.Session() as sess:  #上下文管理
    print(sess.run(sum1))   #run运行加法op
    # print(sess.run(a,b,sum1)) #错误,这样相当于将a,b和sum1当作三个参数
    print(sess.run([a,b,sum1])) #要将a,b和sum1当作一个整体,放进列表中或者元组中

输出:

6
[2, 4, 6]

(2)重载运算符

  •  两个非tensor的量相加(不是op)不能使用run方法
1  var1 = 3.0
2  var2 = 5.0
3  sum2 = var1 + var2
4 
5  with tf.Session() as sess:  #上下文管理
6      print(sess.run(sum2))   

输出:

TypeError: Fetch argument 8.0 has invalid type <class 'float'>, must be a string or Tensor. (Can not convert a float into a Tensor or Operation.)

 

  • 一个tensor和一个非tensor运算,非tensor会被重载为tensor
 1  import tensorflow as tf
 2  import os
 3  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #去掉警告,将警告级别提升
 4 
 5  a = tf.constant(2.0)   #定义一个常量
 6  b = tf.constant(4.0)
 7  # sum1 = tf.add(a,b)    #加法操作
 8 
 9  var1 = 3.0
10  #默认会给运算符重载成op类型
11  sum3 = a + var1  #a和var1的数据类型要匹配
12 
13  with tf.Session() as sess:  #上下文管理
14      print(sess.run(sum3))
15      print(a)  #直接打印是一个tensor Tensor("Const:0", shape=(), dtype=float32)
16      print(var1) # 3.0
17      print(sum3)  #Tensor("add:0", shape=(), dtype=float32)

输出

5.0
Tensor("Const:0", shape=(), dtype=float32)
3.0
Tensor("add:0", shape=(), dtype=float32)

 (3)feed_dict参数,允许调用者覆盖图中指定张量的值,提供给placeholder使用

训练模型时,不知道每批次有多少个样本 ,即有些量不固定,我们需要实时的提供数据去进行训练

 1 import tensorflow as tf
 2 import os
 3 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #去掉警告,将警告级别提升
 4 
 5 #placeholder是一个占位符,通常与feed_dict配合使用
 6 plt1 = tf.placeholder(tf.float32,[2,3])
 7 
 8 plt2 = tf.placeholder(tf.float32,[None,3]) #样本数量不固定用None,Tensor("Placeholder_1:0", shape=(?, 3), dtype=float32)
 9 print(plt2)
10 
11 with tf.Session() as sess:  #上下文管理
12     #相当于在运算的时候实时的提供数据进行训练
13     print(“plt1:\n”,sess.run(plt1,feed_dict={ plt1:[[1,2,3],[4,5,6]]})) #feed_dict是一个字典,key是变量,value是2x3的值
14     print("plt2:\n",sess.run(plt2, feed_dict={plt2: [[1, 2, 3], [4, 5, 6],[7,8,9]]}))#因为行数不固定,所以可接收任意行数的值

输出:

Tensor("Placeholder_1:0", shape=(?, 3), dtype=float32)

plt1: [[
1. 2. 3.] [4. 5. 6.]] plt2: [[1. 2. 3.] [4. 5. 6.] [7. 8. 9.]]

 

会话最重要的是run,作用是运行的图的结构,即运行一些op,再得到结果

(4)TensorFlow Feed操作

意义:在程序执行的时候,不确定输入的是什么,提前“占个坑”

语法:placeholder提供占位符,在会话中run的时候通过feed_dict指定参数

  • placeholder(dtype, shape=None, name=None)

    shape要用列表表示,默认是一个常量

  • run(self, fetches, feed_dict=None, options=None, run_metadata=None)
例1:
1
import tensorflow as tf 2 import os 3 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #去掉警告,将警告级别提升 4 5 input1 = tf.placeholder(tf.float32,[2,1]) #[2,1] 2行1列 6 input2 = tf.placeholder(tf.float32) #默认是一个数 7 output = tf.add(input1,input2) 8 # placeholder 提供占位符,run的时候通过feed_dict指定参数 9 10 with tf.Session() as sess: 11 result = sess.run(output,feed_dict={input1:[[10],[1]],input2:20}) 12 print(result)

 

输出:

[[30.]
 [21.]]

 

 例2:
1
import tensorflow as tf 2 import os 3 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #去掉警告,将警告级别提升 4 5 input1 = tf.placeholder(tf.float32,[2,2]) 6 input2 = tf.placeholder(tf.float32,1) 7 8 sum = tf.add(input1,input2) 9 10 with tf.Session() as sess: 11 result = sess.run(sum, feed_dict={input1:[[1,2],[3,4]],input2:[3]}) 12 print(result)

 输出:

[[4. 5.]
 [6. 7.]]

推荐阅读