最近,想使用谷歌的Attention OCR做中文文本识别,项目github地址
<https://github.com/A-bone1/Attention-ocr-Chinese-Version>
:https://github.com/A-bone1/Attention-ocr-Chinese-Version,中文介绍可参考CSDN博客
<https://blog.csdn.net/qq_40003316/article/details/80062023>
:https://blog.csdn.net/qq_40003316/article/details/80062023。       
研究后发现该模型的训练数据需要提供FSNS格式的训练数据,而官方也没有给出相关的文档,只给了一个stackoverflow的链接
https://stackoverflow.com/a/44461910/743658
<https://stackoverflow.com/a/44461910/743658>
,可是说的也不清楚。所以自己参考网上的一些办法,写了一个生成FSNS格式tfrecord的小代码。github地址
<https://github.com/A-bone1/FSNS-tfrecord-generate>
为:https://github.com/A-bone1/FSNS-tfrecord-generate。
      FSNS的具体格式在这篇论文有说:https://arxiv.org/pdf/1702.03970.pdf
<https://arxiv.org/pdf/1702.03970.pdf>
      但是,我们只需关心表四即可:
image/format表示图片的格式,是‘png’ ,如果你生的tfrecord是使用jpg格式,可改成‘raw’
image/encoded 表示图片的具体内容,占用一个string,以‘png’的格式编码

iamge/class表示图片真实的类别id,是37个int64数据,每一个int64对应一个字符编码,具体的映射方式在charset_size=134.txt文件中,要生成自己的数据需要自己创建类似的字典,如我自己创建的包含5400个中文的
dic.txt <https://github.com/A-bone1/FSNS-tfrecord-generate>。
image/unpadded_class 表示图片在没有被填充之前真实的id。
image/width:表示图片的像素的宽度
image/orig_width:表示图片在没有填充之前像素的宽度
image/height:表示图片的像素的高度,在tensorflow代码中,这一部分并没有写入代码,因为图片高度固定为150
image/test:占用一个string,是使用UTF-8编码的真实的字符形式的标记 
        下面直接上代码:(上传的代码是将jpg图片直接存储为tfrecord,速度较快,如果读者想生成png编码的tfrecord,可以参考我的
github <https://github.com/A-bone1/FSNS-tfrecord-generate>。

from random import shuffle import numpy as np import glob import tensorflow as
tf import cv2 import sys import os import PIL.Image as Image def
encode_utf8_string(text, length, dic, null_char_id=5462): char_ids_padded =
[null_char_id]*length char_ids_unpadded = [null_char_id]*len(text) for i in
range(len(text)): hash_id = dic[text[i]] char_ids_padded[i] = hash_id
char_ids_unpadded[i] = hash_id return char_ids_padded, char_ids_unpadded def
_bytes_feature(value): return
tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) def
_int64_feature(value): return
tf.train.Feature(int64_list=tf.train.Int64List(value=value)) dict={} with
open('dic.txt', encoding="utf") as dict_file: for line in dict_file: (key,
value) = line.strip().split('\t') dict[value] = int(key) print((dict))
image_path = 'data/*/*.jpg' addrs_image = glob.glob(image_path) label_path =
'data/*/*.txt' addrs_label = glob.glob(label_path) print(len(addrs_image))
print(len(addrs_label)) tfrecord_writer =
tf.python_io.TFRecordWriter("tfexample_train") for j in
range(0,int(len(addrs_image))): # 这是写入操作可视化处理 print('Train data:
{}/{}'.format(j,int(len(addrs_image)))) sys.stdout.flush() img =
Image.open(addrs_image[j]) img = img.resize((600, 150), Image.ANTIALIAS)
np_data = np.array(img) image_data = img.tobytes() for text in
open(addrs_label[j], encoding="utf"): char_ids_padded, char_ids_unpadded =
encode_utf8_string( text=text, dic=dict, length=37, null_char_id=5462) example
= tf.train.Example(features=tf.train.Features( feature={ 'image/encoded':
_bytes_feature(image_data), 'image/format': _bytes_feature(b"raw"),
'image/width': _int64_feature([np_data.shape[1]]), 'image/orig_width':
_int64_feature([np_data.shape[1]]), 'image/class':
_int64_feature(char_ids_padded), 'image/unpadded_class':
_int64_feature(char_ids_unpadded), 'image/text': _bytes_feature(bytes(text,
'utf-8')), # 'height': _int64_feature([crop_data.shape[0]]), } ))
tfrecord_writer.write(example.SerializeToString()) tfrecord_writer.close()
sys.stdout.flush()
  

友情链接
KaDraw流程图
API参考文档
OK工具箱
云服务器优惠
阿里云优惠券
腾讯云优惠券
华为云优惠券
站点信息
问题反馈
邮箱:ixiaoyang8@qq.com
QQ群:637538335
关注微信