본문 바로가기
Deep Learning/Hands On Machine Learning

13.2 TFRecord 포맷

by 대소기 2021. 11. 12.

TFRecord 포맷

  • tensorflow는 대용량 데이터를 저장하기 위해 tfrecord라는 포맷을 사용한다.
  • tfrecord는 크기가 다른 여러가지 레코드를 저장하는 이진 포맷이다.
  • 각 레코드는 레코드 길이 CRC checksum(길이가 올바른지 체크하는), 실제 데이터, 데이터를 위한 CRC checksum으로 구성된다.
with tf.io.TFRecordWriter("my_data.tfrecord") as f:
    f.write(b"This is the first record")
    f.write(b"And this is the second record")
  • 위 코드와 같은 방법으로 tfrecord를 작성할 수 있다.
filepaths = ["my_data.tfrecord"]
dataset = tf.data.TFRecordDataset(filepaths)
for item in dataset:
    print(item)

#tf.Tensor(b'This is the first record', shape=(), dtype=string)
#tf.Tensor(b'And this is the second record', shape=(), dtype=string)
  • tfrecord를 불러올 때는 filepath를 인자로 사용해 tf.data.TFRecordDataset()을 통해 불러올 수 있다.
filepaths = ["my_test_{}.tfrecord".format(i) for i in range(5)]
for i, filepath in enumerate(filepaths):
    with tf.io.TFRecordWriter(filepath) as f:
        for j in range(3):
            f.write("File {} record {}".format(i, j).encode("utf-8"))

dataset = tf.data.TFRecordDataset(filepaths, num_parallel_reads=3)
for item in dataset:
    print(item)
  • list_files()와 interleave()를 사용했던 것 처럼 여러 파일에서 레코드를 위 코드처럼 번갈아가며 읽을 수도 있다.

13.2.1 압축된 TFRecord 파일

  • tfrecord file을 압축하여 저장할 수 있다.

options = tf.io.TFRecordOptions(compression_type="GZIP")
with tf.io.TFRecordWriter("my_compressed.tfrecord", options) as f:
    f.write(b"This is the first record")
    f.write(b"And this is the second record")
  • 저장하는 코드이다. options만 지정해주면 압축이 가능하다.

dataset = tf.data.TFRecordDataset(["my_compressed.tfrecord"],
                                  compression_type="GZIP")
for item in dataset:
    print(item)

13.2.2 프로토콜 버퍼 개요

프로토콜 버퍼

직렬화

  • 데이터를 파일로 저장하거나 네트워크 통신이 가능하도록 형식을 바꾸어 주는 것.

프로토콜 버퍼

  • 프로토콜 버퍼는 google이 개발한 이진 포맷으로 파일 저장이나 네트워크 전송 등을 위해 사용한다.
  • 직렬화 된 데이터를 이진 포맷으로 저장하기 때문에 더 적은 용량으로 데이터 전송이 가능하다.
  • 데이터를 직렬화 하기 때문에 Language Neutral하다.
  • jpg, png파일과 같은 이미지 파일들을 사용할 때 필요한 인코딩, 디코딩 작업이 필요 없이 직렬화된 데이터를 읽으면 되므로 편리하다.
  • 보통 데이터를 보관할 때 data, target을 분리하여 보관하게 되는데, 프로토콜 버퍼를 사용하면 직렬화 하여 하나의 파일로 보관할 수 있기 때문에 data와 target을 매칭해주는 코드를 추가적으로 작성할 필요가 없어져 불필요한 코드를 줄일 수 있다.

프로토콜 버퍼 생성

#person.proto로 파일 저장
%%writefile person.proto 
syntax = "proto3"; #protocol buffer version3
message Person {
  string name = 1;
  int32 id = 2;
  repeated string email = 3;
}
  • C언어의 구조체 형식과 비슷하게 프로토콜 버퍼를 만들 수 있다.
  • 1, 2, 3은 각각 필드 식별자로 데이터의 이진 표현에 사용된다.

프로토콜 버퍼 컴파일


!protoc person.proto --python_out=. --descriptor_set_out=person.desc --include_imports
  • 프로토콜 버퍼는 protoc라는 c언어 기반 컴파일러를 통해 컴파일을 진행하고 --python_out=. 옵션을 통해 python 클래스를 생성해야 사용할 수 있다.
!ls
# person.desc  person_pb2.py  person.proto
  • 컴파일이 끝나면 디렉토리에 person.desc, person_pb2.py 2개의 파일이 추가된다. 이중 pb2(protocol buffer 2)가 붙은 파일을 import하여 클래스를 사용할 수 있다.

from person_pb2 import Person

person=Person(name='AI', id=123, email=['a@b.com'])
person.name='Alice' #필드는 수정 가능하다.
person.email.append('c@d.com') 
s = person.SerializeToString()
s
#b'\n\x05Alice\x10{\x1a\x07a@b.com\x1a\x07c@d.com'
  • 프로토콜 버퍼를 사용한 Person 클래스를 사용해 person객체를 만들었다.
  • person 객체의 각 필드들은 수정이 가능하고, 반복 필드(배열)의 경우 인덱싱을 통해 참조 또한 가능하다.
  • person객체는 ParseFromString() 메소드를 통해 직렬화 할 수 있다.
  • 직렬화 한 데이터는 네트워크를 통해 전송이 가능하다.
from person_pb2 import Person
person2=Person() #객체 생성
person2.ParseFromString(s) #파싱
  • 네트워크를 통해 전송받은 직렬화된 데이터는 ParseFromString()메소드를 통해 파싱이 가능하다.
  • 혹은 직렬화된 데이터를 TFRecord파일로 저장한 후 읽고 파싱하는 것도 가능하다.
  • 하지만, SerializeToString()ParseFromString()은 텐서플로우 연산이 아니기 때문에 텐서플로우 함수에 포함할 수 없다.
  • 텐서플로우에서는 이러한 문제를 해결하기 위해 특별한 프로토콜 버퍼 정의를 가지고 있다. 이를 13.2.3에서 살펴본다.

13.2 텐서플로 프로토콜 버퍼

  • TFRecord파일에서 주로 사용하는 프로토콜 버퍼는 Example 프로토코 버퍼이다.
syntax = "proto3";

message BytesList { repeated bytes value = 1; }
message FloatList { repeated float value = 1 [packed = true]; }
message Int64List { repeated int64 value = 1 [packed = true]; }
message Feature {
    oneof kind {
        BytesList bytes_list = 1;
        FloatList float_list = 2;
        Int64List int64_list = 3;
    }
};
message Features { map<string, Feature> feature = 1; };
message Example { Features features = 1; };
  • Example 프로토콜 버퍼의 구조는 위와 같다.

Example 클래스를 사용해 객체 생성


from tensorflow.train import BytesList, FloatList, Int64List
from tensorflow.train import Feature, Features, Example

person_example=Example(
    features=Features(
        feature={
            "name" : Feature(bytes_list=BytesList(value=[b"Alice"])),
            "id" : Feature(int64_list=Int64List(value=[1,2,3])),
            "emails" : Feature(bytes_list=BytesList(value=[b"a@b.com",
                                                           b"c@d.com"]))
        }
    )
)

데이터 직렬화와 TFRecord형식으로 저장


with tf.io.TFRecordWrite("my_contacts.tfrecord") as f:
    f.write(person_example.SerializeToString())

13.2.3 Example 프로토콜 버퍼를 읽고 파싱하기


feature_description = {
    "name": tf.io.FixedLenFeature([], tf.string, default_value=""), 
    "id": tf.io.FixedLenFeature([], tf.int64, default_value=0),
    "emails": tf.io.VarLenFeature(tf.string),
}
for serialized_example in tf.data.TFRecordDataset(["my_contacts.tfrecord"]):
    parsed_example = tf.io.parse_single_example(serialized_example, #파싱
                                                feature_description)
  • Example 프로토콜 버퍼를 파싱하기 위해서는 parse_single_example()메소드를 사용하여야 한다.
  • Example 프로토콜 버퍼를 읽기 위해서는 feature description을 정의해서 parse_single_example()메소드의 인자로 넣어줘야 한다.
dataset=tf.data.TFRecordDataset(["my_contacts.tfrecord"]).batch(10)
for serialized_examples in dataset:
    parsed_examples=tf.io.parse_example(serialized_examples, feature_description)
  • parse_exmple()메소드를 사용하면 데이터를 하나씩 파싱하는 것이 아니라 배치 단위로 파싱하는 것이 가능해진다.

13.2.4 SequenceExample 프로토콜 버퍼를 사용해 리스트의 리스트 다루기

SequenceExample 프로토콜 버퍼의 정의


message FeatureList { repeated Feature feature=1;};
message FeatureLists { map<string, FeatureList> feature_list=1;};
message SequenceExample{
    Feature context = 1;
    FeatureLists feature_lists=2;

}

  • Features 객체는 문맥 데이터를 정의한다.
  • FeatureLists에는 한 개 이상의 FeatureList가 포함된다.
  • Feature 객체는 바이트 스트링의 리스트나 64비트 정수의 리스트, 실수의 리스트일 수 있다.

parsed_context, parsed_feature_lists=tf.io.parse_single_sequence_example(
    serialized_seqence_example, xontext_feature_descriptions,
    seqence_feature_descriptions)
parsed_context=tf.RaggedTensor.from_sparse(parsed_feature_lists["context"])
  • Sequence 프로토콜 버퍼를 파싱하는 방법은 Example 프로토콜 버퍼를 파싱하는 방법과 동일하다.