تحسين النماذج في Tensorflow 1.x

بينما يفقد Tensorflow مكانته في بيئة البحث ، لا يزال يتمتع بشعبية في التطوير العملي. تتمثل إحدى نقاط القوة في TF التي تبقيها قائمة في القدرة على تحسين النماذج للنشر في البيئات محدودة الموارد. هناك أطر عمل خاصة لهذا: Tensorflow Lite للأجهزة المحمولة وخدمة Tensorflowللاستخدام الصناعي. هناك دروس كافية حول استخدامها على الويب (وحتى على Habré). في هذه المقالة ، قمنا بتجميع خبرتنا في تحسين النماذج دون استخدام هذه الأطر. سننظر في بعض الأساليب والمكتبات التي تنجز المهمة المطروحة ، وسنصف كيف يمكنك توفير مساحة القرص وذاكرة الوصول العشوائي ، ونقاط القوة والضعف في كل نهج ، وبعض التأثيرات غير المتوقعة التي واجهناها.



في أي ظروف نعمل



إحدى مهام البرمجة اللغوية العصبية الكلاسيكية هي التصنيف الموضوعي للنصوص القصيرة. يتم تمثيل المصنفات بالعديد من البنى المختلفة ، بدءًا من الأساليب الكلاسيكية مثل SVC إلى بنيات المحولات مثل BERT ومشتقاتها. سننظر في CNN - النماذج التلافيفية.



أحد القيود المهمة بالنسبة لنا هو الحاجة إلى تدريب النماذج واستخدامها (كجزء من المنتج) على الأجهزة التي لا تحتوي على وحدة معالجة الرسومات. هذا يؤثر في المقام الأول على سرعة التعلم والاستدلال.



شرط آخر هو أن يتم تدريب نماذج التصنيف واستخدامها في مجموعات من عدة قطع. يمكن لمجموعة من النماذج ، حتى البسيطة منها ، استخدام الكثير من الموارد ، وخاصة ذاكرة الوصول العشوائي. نحن نستخدم حلنا الخاص لخدمة النماذج ، ومع ذلك ، إذا كنت بحاجة إلى العمل مع مجموعات من النماذج ، فقم بإلقاء نظرة على خدمة Tensorflow .



لقد واجهتنا الحاجة إلى تحسين النموذج على إصدار TF 1.x ، والذي يعتبر الآن رسميًا قديمًا. بالنسبة إلى TF 2.x ، فإن العديد من التقنيات التي تمت مناقشتها إما غير ملائمة أو مدمجة في API القياسي ، وبالتالي فإن عملية التحسين بسيطة للغاية.



دعنا نلقي نظرة على هيكل نموذجنا أولاً.



كيف يعمل نموذج TF



خذ بعين الاعتبار ما يسمى بـ Shallow CNN - شبكة ذات طبقة تلافيفية واحدة وعدة مرشحات. لقد عمل هذا النموذج جيدًا بما يكفي لتصنيف النص على تمثيلات الكلمات المتجهة.





من أجل التبسيط ، سنستخدم مجموعة محددة مسبقًا من تمثيلات المتجهات للأبعاد v x k ، حيث v هي حجم القاموس ، k هي بُعد التضمينات.



:



  • Embedding-, .
  • w x k. , (1, 1, 2, 3) 4 , 1 , 2 3 , .
  • Max-pooling .
  • , dropout- softmax- .


Adam, .



: .



, , 128 c w = 2 k = 300 () [filter_height, filter_width, in_channels, output_channels] — , 2*300*1*128 = 76800 float32, , 76800*(32/8) = 307200 .



? ( 220 . ) 300 265 . , .



TF . ( ), , , — ( ), . (). :



الرسم البياني الحسابي





. , : SavedModel. , .



Checkpoint



, Saver API:



saver = tf.train.Saver(save_relative_paths=True)
ckpt_filepath = saver.save(sess, "cnn.ckpt"), global_step=0)


global_step , , — cnn-ckpt-0.



<model_path>/cnn_ckpt :





checkpoint — . , TF . , .



.data , . , — 800 . , (≈265 ). ( ). , .



.index .



.meta — , (, , ), GraphDef, . , . — .meta , ? , TF - embedding-. , , , , , . , , :



with tf.Session() as sess:
   saver = tf.train.import_meta_graph('models/ckpt_model/cnn_ckpt/cnn.ckpt-0.meta')  # load meta
   for n in tf.get_default_graph().as_graph_def().node:
       print(n.name, n['attr'].shape)


.



SavedModel



, . . API tf.saved_model. tf.saved_model, TF- (TFLite, TensorFlow.js, TensorFlow Serving, TensorFlow Hub).



:





saved_model.pb, , , .meta , (, ), API, ( CLI, ).



SavedModel — , . “” . , , - — , .





, CNN-, TF 1.x, . .



, 1 , :





  1. . , , ( tools.optimize_for_inference ).


  2. . , , — , tf.trainable_variables().


  3. , . , (. BERT).


  4. . , . .




, , . , forward pass, . , . 1 265 .



TF 1.x , .



( ) GraphDef:



graph = tf.get_default_graph()
input_graph_def = graph.as_graph_def()


. : tf.python.tools.freeze_graph tf.graph_util.convert_variables_to_constants. ( ) (, ['output/predictions']), , , . .



output_graph_def = graph_util.convert_variables_to_constants(self.sess, input_graph_def, output_node_names)


, .

freeze_graph() ( , , ). graph_util.convert_variables_to_constants() :



with tf.io.gfile.GFile('graph.pb', 'wb') as f:
    f.write(output_graph_def.SerializeToString())


266 , :



#  GraphDef  

with tf.io.gfile.GFile(graph_filepath, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    #   
    self.input_x = tf.placeholder(tf.int32, [None, self.properties.max_len], name="input_x")
    self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob")
    #        graph_def
    input_map = {'input_x': self.input_x, 'dropout_keep_prob': self.dropout_keep_prob}
    tf.import_graph_def(graph_def, input_map)


, import:



predictions = graph.get_tensor_by_name('import/output/predictions:0')


:



feed_dict = {self.input_x: encode_sentence(sentence), self.dropout_keep_prob: 1.}
sess.run(self.predictions, feed_dict)


, :



  1. . , sess.run(...). , CPU 20 ms, ~2700 ms. , . SavedModel .
  2. RAM. RAM, . ~265 , . , TF GraphDef .
  3. – RAM TF . 1.15, TF 1.x, 118 MiB, 1.14 – 3 MiB.




, . ? / TF- tf.train.Saver. , , , :



  • MetaGraph


tf.train.Saver . , :



saver = tf.train.Saver(var_list=tf.trainable_variables())


MetaGraph . , meta . MetaGraph save:



ckpt_filepath = saver.save(self.sess, filepath, write_meta_graph=False)


1014 M 265 M ( , ).





Pruning — , , . , .



, TF 1.x:



  • Grappler: c tensorflow
  • Pruning API: google-research
  • Graph Transform Tool:


, — tensorflow, Grappler. Grappler . , set_experimental_options. , zip . , zip , . Grappler .



google-research mask threshold, . . , , mask threshold, , , . .



Grappler, . : ? , ? , 0.99 . , mc, hex :



hex-



, , . . -, . -, , , , . , .



CNN. .





, . Graph transform tool.



quantize_weights 8 . , 8- . , , - .



quantize_nodes 8- . .



, - . quantize_weights - , 4 .



, , TensorFlow Lite, .





— , . 64 (32) , .



RAM Ubuntu ( numpy int64) . 220 , int32, int16. .





tf-. float16. , , ( 10%), ( 10 ). , , epsilon learning_rate . , , .



RAM



, . , .





, . . .



QA-



Q: -, - ?



A: , . word2vec. ( , , min count, learning rate), 220 ( — 265 MB) CNN, 439 (510 MB).



- , , , - . , ( ). , . YouTokenToMe, , , .. , .., . . , , , . 30 (37 MB) , 3.7 CPU 2.6 GPU. ( ), OOV-.



Q: , , ?



A: , .



:



1. :



with tf.gfile.GFile(path_to_pb, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')
    return graph


2. "" :



sess.run(restored_variable_names)


3. , .

4. , , :



tf.Variable(tensors_to_restore["output/W:0"], name="W")


, .



, , .



لم نحاول إعادة تدريب النماذج المضغوطة بواسطة بقية الطرق الموصوفة ، ولكن نظريًا لا ينبغي أن يكون هناك أي مشاكل مع هذا.



س: هل هناك طرق أخرى لتقليل التحسين لم تفكر فيها؟



ج: لدينا العديد من الأفكار التي لم ندركها قط. أولاً ، الطي الثابت هو "طي" لمجموعة فرعية من عقد الرسم البياني ، وهو الحساب المسبق لقيم أجزاء الرسم البياني التي تعتمد بشكل ضعيف على بيانات الإدخال. ثانيًا ، في نموذجنا ، يبدو أنه حل جيد لتطبيق تقليم حفلات الزفاف.




All Articles