TensorFlow 1.x Study Notes (Series 2, Part 2): Dynamic and Static Tensor Shapes, Basic Tensor APIs



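Tensor rank and data types:

1: A tensor is TensorFlow's basic data format

2: It is a typed N-dimensional array (tf.Tensor)

3: It has three parts: a name, a shape, and a data type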

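Tensor attributes:

.graph: the graph the tensor belongs to (the default graph unless another graph is used)

.op: the operation that produces the tensor

.name: the tensor's string name, e.g. 'Placeholder:0'

.shape: the tensor's (static) shape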

How TensorFlow prints shapes:

0-D: ()

1-D: (n,)

2-D: (n, m)

3-D: (n, m, l)
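These printing rules can be reproduced in a few lines of plain Python. In this sketch, `format_shape` is a hypothetical helper, not a TensorFlow function. It assumes an unknown dimension is stored as `None` and printed as `?`, which matches how TF 1.x displays the placeholder shape `(?, 2)` later in this post. The rank is simply the number of dimensions.

```python
# Pure-Python model of how TF 1.x prints a shape; format_shape is hypothetical.
def format_shape(shape):
    """Render a shape tuple the way TensorFlow 1.x prints a TensorShape.

    Unknown dimensions are given as None and printed as '?'.
    """
    dims = ["?" if d is None else str(d) for d in shape]
    if len(dims) == 1:
        # A 1-D shape keeps the trailing comma, like a Python 1-tuple.
        return "(%s,)" % dims[0]
    return "(%s)" % ", ".join(dims)


# Print the rank (number of dimensions) next to the formatted shape.
for shape in [(), (5,), (3, 4), (None, 2), (2, 3, 4)]:
    print(len(shape), format_shape(shape))
```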



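1. Dynamic and static shapes of tensors

• In TensorFlow, a tensor has both a static shape and a dynamic shape.

• Static shape: the shape known when the tensor is created, i.e. while the graph is being built. Some dimensions may still be unknown; these are printed as ?.

• tf.Tensor.get_shape(): returns the static shape.

• tf.Tensor.set_shape(): updates the static shape of a Tensor object in place.

• Dynamic shape: the actual shape of the tensor while the graph is executing. It can differ from one run to the next.

• tf.reshape(): creates a new tensor with a different dynamic shape.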
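The relationship between the two shapes can be modeled in plain Python. This is a sketch, not TensorFlow code, and `is_compatible` is a hypothetical helper. The static shape is a tuple in which `None` marks a dimension that is unknown at graph-build time. Each run supplies data with a fully known dynamic shape. The rule modeled here is an assumption: data fits only when the rank matches and every known static dimension agrees. Under that rule, a placeholder declared as `[None, 2]` can receive 3 rows in one run and 100 in the next.

```python
# Pure-Python model (not TensorFlow code); is_compatible is hypothetical.
def is_compatible(static_shape, runtime_shape):
    """Return True if a concrete runtime shape fits a (partial) static shape.

    static_shape may contain None for dimensions unknown at graph-build time;
    runtime_shape is the fully known shape of the actual data.
    """
    if len(static_shape) != len(runtime_shape):
        return False  # the rank must always agree
    return all(s is None or s == r for s, r in zip(static_shape, runtime_shape))


static = (None, 2)  # like tf.placeholder(tf.float32, [None, 2])
for runtime in [(3, 2), (100, 2), (3, 3), (3, 2, 1)]:
    print(runtime, is_compatible(static, runtime))
```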



Setting the static shape

import tensorflow as tf
plt = tf.placeholder(tf.float32,[None,2])
plt
<tf.Tensor 'Placeholder:0' shape=(?, 2) dtype=float32>
plt.set_shape([3,2])
plt
<tf.Tensor 'Placeholder:0' shape=(3, 2) dtype=float32>



Wrong modification 1: set_shape() cannot change the rank (number of dimensions) of the static shape

plt.set_shape([3,2,2])
---------------------------------------------------------------------------
ValueError: Shapes (3, 2) and (3, 2, 2) must have the same rank

During handling of the above exception, another exception occurred:

ValueError: Shapes (3, 2) and (3, 2, 2) are not compatible



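Wrong modification 2: changing a dimension that is already fixed raises an error. set_shape() cannot overwrite a known dimension, although repeating a compatible shape such as [3,2] is accepted.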

plt.set_shape([4,2])

---------------------------------------------------------------------------
ValueError: Dimensions 3 and 4 are not compatible

During handling of the above exception, another exception occurred:

ValueError: Shapes (3, 2) and (4, 2) are not compatible
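Both failures above come from one rule: set_shape() merges the requested shape with the current static shape. Merging can fill in unknown dimensions (`None`), but it can never change the rank or a dimension that is already known. The sketch below is a simplified pure-Python model of that rule. `merge_shapes` is a hypothetical helper, not TensorFlow's own code, although its error messages imitate the ones shown above.

```python
# Simplified pure-Python model of set_shape()'s merge rule; merge_shapes is hypothetical.
def merge_shapes(current, new):
    """Merge two partial shapes; None means "unknown".

    Raises ValueError on a rank conflict or a conflicting known dimension.
    """
    if len(current) != len(new):
        raise ValueError("Shapes %s and %s must have the same rank" % (current, new))
    merged = []
    for a, b in zip(current, new):
        if a is not None and b is not None and a != b:
            raise ValueError("Dimensions %s and %s are not compatible" % (a, b))
        merged.append(a if a is not None else b)
    return tuple(merged)


shape = merge_shapes((None, 2), (3, 2))
print(shape)                        # (3, 2): the unknown dimension is filled in
print(merge_shapes(shape, (3, 2)))  # (3, 2): repeating a compatible shape is fine
for bad in [(3, 2, 2), (4, 2)]:
    try:
        merge_shapes(shape, bad)
    except ValueError as err:
        print("ValueError:", err)
```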



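Changing the shape through the dynamic shape instead: tf.reshape() creates a new tensor, and plt itself keeps its static shape of (3, 2)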

plt_reshape = tf.reshape(plt,[2,3])
plt_reshape
<tf.Tensor 'Reshape:0' shape=(2, 3) dtype=float32>



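When changing the dynamic shape, the number of elements must match: plt holds 3 × 2 = 6 elements, so it cannot become [3,3] (9 elements)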

plt_reshape2 = tf.reshape(plt,[3,3])
plt_reshape2

---------------------------------------------------------------------------
ValueError: Cannot reshape a tensor with 6 elements to shape [3,3] (9 elements) for 'Reshape_1' (op: 'Reshape') with input shapes: [3,2], [2] and with input tensors computed as partial shapes: input[1] = [3,3].
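tf.reshape also accepts -1 for at most one dimension, and that dimension is then inferred from the element count. The pure-Python model below shows this check together with the element-count check that rejected [3,3] above. `resolve_reshape` is a hypothetical helper: it imitates the idea, but it is not TensorFlow's shape-inference code.

```python
# Pure-Python model of tf.reshape's element-count check; resolve_reshape is hypothetical.
from math import prod


def resolve_reshape(num_elements, target):
    """Resolve at most one -1 in target, then require matching element counts."""
    if list(target).count(-1) > 1:
        raise ValueError("only one dimension can be -1")
    known = prod(d for d in target if d != -1)
    if -1 in target:
        if known == 0 or num_elements % known != 0:
            raise ValueError("cannot infer the -1 dimension")
        target = [num_elements // known if d == -1 else d for d in target]
    if prod(target) != num_elements:
        raise ValueError("Cannot reshape a tensor with %d elements to shape %s (%d elements)"
                         % (num_elements, list(target), prod(target)))
    return list(target)


print(resolve_reshape(6, [2, 3]))   # [2, 3]
print(resolve_reshape(6, [-1, 2]))  # [3, 2]: -1 is inferred as 6 // 2
try:
    resolve_reshape(6, [3, 3])
except ValueError as err:
    print("ValueError:", err)
```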



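Summary

Static shape: set_shape() can only fill in dimensions that are still unknown (?). Once a dimension is fixed it cannot be changed to a different value, and the rank (number of dimensions) can never be changed.

Dynamic shape: tf.reshape() creates a new tensor with the requested shape and leaves the original tensor untouched. The total number of elements must stay the same.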



2. Tensor operations



Creating tensors



Constant-value tensors:

tf.zeros(shape,dtype = tf.float32,name = None) # creates a tensor with every element set to 0

tf.ones(shape,dtype = tf.float32,name = None) # creates a tensor with every element set to 1

Creating an all-zeros tensor and an all-ones tensor

zero = tf.zeros([3,4],tf.float32)
zero
<tf.Tensor 'zeros:0' shape=(3, 4) dtype=float32>
one = tf.ones([3,4],tf.float32)
one
<tf.Tensor 'ones:0' shape=(3, 4) dtype=float32>

Evaluating them in a session:

with tf.Session() as sess:
    print(sess.run([zero,one]))
[array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32), array([[1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.]], dtype=float32)]
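The values above are easy to reproduce without TensorFlow. As a rough plain-Python analogue, the hypothetical helper `filled` builds nested lists recursively, one level per dimension. A [3, 4] shape therefore gives 3 rows of 4 values, and an empty shape gives a scalar. Calling it with 0.0 or 1.0 plays the role of tf.zeros or tf.ones.

```python
# Plain-Python analogue of tf.zeros / tf.ones; the helper name is hypothetical.
def filled(shape, value):
    """Build a nested list of the given shape with every element equal to value."""
    if not shape:  # 0-D: a scalar
        return value
    # A fresh sub-list is built for every row, so rows never alias each other.
    return [filled(shape[1:], value) for _ in range(shape[0])]


print(filled([3, 4], 0.0))
print(filled([3, 4], 1.0))
print(filled([], 1.0))  # a scalar, shape ()
```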



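Random-value tensors:

tf.random_normal(shape,mean = 0.0,stddev = 1.0,dtype = tf.float32,seed = None,name = None) # a tensor of random values drawn from a normal distribution with the given mean and standard deviation

Creating a random tensor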

random_1 = tf.random_normal([3,4],mean = 0.0,stddev = 1,dtype = tf.float32)
random_1
<tf.Tensor 'random_normal:0' shape=(3, 4) dtype=float32>

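Output (printing the tensor object only shows its name, shape and dtype; the values exist only after sess.run() evaluates it):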

with tf.Session() as sess:
    print(random_1,'\n')
    print(sess.run(random_1),'\n')
    print(random_1)
Tensor("random_normal:0", shape=(3, 4), dtype=float32) 

[[ 1.1654192   0.7380267   0.44780177  0.9135542 ]
 [ 0.22852856 -0.0483947   0.3628924   0.13235609]
 [ 0.6036132  -0.47465008  0.10996567 -1.3579988 ]] 

Tensor("random_normal:0", shape=(3, 4), dtype=float32)
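The seed parameter in the signature above is what makes random values repeatable. The sketch below illustrates the idea by analogy only. It uses Python's random module rather than TensorFlow's generator, so its numbers will not match TF's. The function name `random_normal` here is a hypothetical stand-in for a 2-D tf.random_normal. The sketch draws a normally distributed matrix with `random.Random.gauss` and shows that the same seed gives the same numbers.

```python
# Analogy only: Python's random module, not TensorFlow's generator.
import random
import statistics


def random_normal(shape, mean=0.0, stddev=1.0, seed=None):
    """Hypothetical stand-in for a 2-D tf.random_normal: a rows x cols
    nested list of samples from a normal distribution N(mean, stddev**2)."""
    rng = random.Random(seed)  # a private generator, so a seed makes it repeatable
    rows, cols = shape
    return [[rng.gauss(mean, stddev) for _ in range(cols)] for _ in range(rows)]


a = random_normal([3, 4], seed=42)
b = random_normal([3, 4], seed=42)
print(a == b)  # True: same seed, same numbers

# With many samples, the sample mean and standard deviation approach mean and stddev.
big = random_normal([1, 10000], mean=5.0, stddev=2.0, seed=0)[0]
print(statistics.mean(big), statistics.stdev(big))
```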



Tensor transformations



Tensor transformations (changing the data type)

tf.cast(x,dtype,name = None)

tf.string_to_number(string_tensor,out_type = tf.float32,name = None)

tf.to_double(x,name = 'ToDouble');tf.to_float(x,name = 'ToFloat');…

Type conversion (tf.cast is the most commonly used)

constant_2 = tf.constant([1,2,3,4,5,6],shape = [2,3])
zeros_2 = tf.zeros([3,3],dtype = tf.float32)  # shape dimensions should be integers
ones_2 = tf.ones([3,3],dtype = tf.int32)

convert_1 = tf.cast(constant_2,tf.float32)
convert_2 = tf.cast(zeros_2,tf.int32)
convert_3 = tf.to_int32(ones_2,name = 'ToInt32')

Output of the conversions:

with tf.Session() as sess:
    print(sess.run(convert_1))
    print(sess.run(convert_2))
    print(sess.run(convert_3))
[[1. 2. 3.]
 [4. 5. 6.]]
[[0 0 0]
 [0 0 0]
 [0 0 0]]
[[1 1 1]
 [1 1 1]
 [1 1 1]]
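The outputs above hide one detail: casting a float to an integer type drops the fractional part rather than rounding. The tf.cast docstring gives the example [1.8, 2.2] becoming [1, 2]. The zeros above are exact, so the effect is invisible there. The plain-Python sketch below mirrors this behavior. `cast_to_int` is a hypothetical helper, and the assumption that negative values also truncate toward zero is labeled as such; Python's int() behaves that way.

```python
# Plain-Python model of a float -> int cast; cast_to_int is hypothetical.
def cast_to_int(values):
    """Drop the fractional part of each value (truncation toward zero,
    which is what Python's int() does; assumed to match tf.cast)."""
    return [int(v) for v in values]


print(cast_to_int([1.8, 2.2]))  # [1, 2]: not rounded to [2, 2]
print(cast_to_int([-1.7]))      # [-1]: truncation toward zero (assumption for TF)
```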



Tensor transformations (changing the shape)

tf.reshape(tensor,shape,name = None)

a = tf.ones([2,4])

with tf.Session() as sess:
    print(sess.run(tf.reshape(a,[4,2])))
[[1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]]
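An all-ones tensor hides the order in which tf.reshape refills the new shape. With distinct values the order becomes visible. Elements are read out in row-major order (last dimension fastest) and written back in the same order. The pure-Python model below works on nested lists. `flatten` and `reshape` are hypothetical helpers, not TensorFlow functions.

```python
# Pure-Python model of a row-major reshape; flatten and reshape are hypothetical.
def flatten(nested):
    """Flatten a nested list in row-major (C) order."""
    if not isinstance(nested, list):
        return [nested]
    return [x for item in nested for x in flatten(item)]


def reshape(nested, shape):
    """Row-major reshape of a nested list; the element count must match."""
    flat = flatten(nested)
    size = 1
    for d in shape:
        size *= d
    if size != len(flat):
        raise ValueError("cannot reshape %d elements into %s" % (len(flat), shape))

    def build(values, dims):
        # Peel off the outermost dimension and split the flat values evenly.
        if not dims:
            return values[0]
        if dims[0] == 0:
            return []
        step = len(values) // dims[0]
        return [build(values[i * step:(i + 1) * step], dims[1:])
                for i in range(dims[0])]

    return build(flat, list(shape))


m = [[1, 2, 3, 4],
     [5, 6, 7, 8]]           # shape [2, 4]
print(reshape(m, [4, 2]))    # [[1, 2], [3, 4], [5, 6], [7, 8]]
```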



Concatenating tensors

tf.concat(values,axis,name = 'concat') # joins a list of tensors along the given axis

a = tf.ones([2,2])
b = tf.zeros([2,2])

with tf.Session() as sess:
    print(sess.run(tf.concat([a,b],axis = 0)))
    print(sess.run(tf.concat([a,b],axis = 1)))
[[1. 1.]
 [1. 1.]
 [0. 0.]
 [0. 0.]]
[[1. 1. 0. 0.]
 [1. 1. 0. 0.]]
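The two results follow one rule. All dimensions except axis must agree, and the sizes along axis add up: [2, 2] and [2, 2] give [4, 2] for axis = 0 and [2, 4] for axis = 1. A minimal pure-Python model for the 2-D case is below. `concat2d` is a hypothetical helper, not a TensorFlow function.

```python
# Pure-Python model of tf.concat for two 2-D inputs; concat2d is hypothetical.
def concat2d(a, b, axis):
    """Concatenate two 2-D nested lists like tf.concat([a, b], axis)."""
    if axis == 0:
        # Stack rows: the column counts must match.
        if a and b and len(a[0]) != len(b[0]):
            raise ValueError("dimension 1 must match for axis=0")
        return [row[:] for row in a] + [row[:] for row in b]
    if axis == 1:
        # Join rows side by side: the row counts must match.
        if len(a) != len(b):
            raise ValueError("dimension 0 must match for axis=1")
        return [ra + rb for ra, rb in zip(a, b)]
    raise ValueError("axis must be 0 or 1 for 2-D inputs")


a = [[1.0, 1.0], [1.0, 1.0]]
b = [[0.0, 0.0], [0.0, 0.0]]
print(concat2d(a, b, axis=0))  # 4 rows of 2: shape [4, 2]
print(concat2d(a, b, axis=1))  # 2 rows of 4: shape [2, 4]
```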



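Copyright notice: This is an original article by qq_33489955, licensed under CC 4.0 BY-SA. If you repost it, please include a link to the original source and this notice.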