Spaces:
				
			
			
	
			
			
		Paused
		
	
	
	
			
			
	
	
	
	
		
		
		Paused
		
	| // Copyright 2016 Google Inc. | |
| // | |
| // Licensed under the Apache License, Version 2.0 (the "License"); | |
| // you may not use this file except in compliance with the License. | |
| // You may obtain a copy of the License at | |
| // | |
| // http://www.apache.org/licenses/LICENSE-2.0 | |
| // | |
| // Unless required by applicable law or agreed to in writing, software | |
| // distributed under the License is distributed on an "AS IS" BASIS, | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| // See the License for the specific language governing permissions and | |
| // limitations under the License.! | |
| ABSL_FLAG(std::string, model, "", "model file name"); | |
| ABSL_FLAG(std::string, input, "", "input filename"); | |
| ABSL_FLAG(std::string, output, "", "output filename"); | |
| ABSL_FLAG(std::string, input_format, "piece", "choose from piece or id"); | |
| ABSL_FLAG(std::string, output_format, "string", "choose from string or proto"); | |
| ABSL_FLAG(std::string, extra_options, "", | |
| "':' separated encoder extra options, e.g., \"reverse:bos:eos\""); | |
| int main(int argc, char *argv[]) { | |
| sentencepiece::ScopedResourceDestructor cleaner; | |
| sentencepiece::ParseCommandLineFlags(argv[0], &argc, &argv, true); | |
| std::vector<std::string> rest_args; | |
| if (absl::GetFlag(FLAGS_input).empty()) { | |
| for (int i = 1; i < argc; ++i) { | |
| rest_args.push_back(std::string(argv[i])); | |
| } | |
| } else { | |
| rest_args.push_back(absl::GetFlag(FLAGS_input)); | |
| } | |
| if (rest_args.empty()) | |
| rest_args.push_back(""); // empty means that reading from stdin. | |
| CHECK(!absl::GetFlag(FLAGS_model).empty()); | |
| sentencepiece::SentencePieceProcessor sp; | |
| CHECK_OK(sp.Load(absl::GetFlag(FLAGS_model))); | |
| CHECK_OK(sp.SetDecodeExtraOptions(absl::GetFlag(FLAGS_extra_options))); | |
| auto output = | |
| sentencepiece::filesystem::NewWritableFile(absl::GetFlag(FLAGS_output)); | |
| CHECK_OK(output->status()); | |
| std::string detok, line; | |
| sentencepiece::SentencePieceText spt; | |
| std::function<void(const std::vector<std::string> &pieces)> process; | |
| auto ToIds = [&](const std::vector<std::string> &pieces) { | |
| std::vector<int> ids; | |
| ids.reserve(pieces.size()); | |
| for (const auto &s : pieces) { | |
| ids.push_back(atoi(s.c_str())); | |
| } | |
| return ids; | |
| }; | |
| if (absl::GetFlag(FLAGS_input_format) == "piece") { | |
| if (absl::GetFlag(FLAGS_output_format) == "string") { | |
| process = [&](const std::vector<std::string> &pieces) { | |
| CHECK_OK(sp.Decode(pieces, &detok)); | |
| output->WriteLine(detok); | |
| }; | |
| } else if (absl::GetFlag(FLAGS_output_format) == "proto") { | |
| process = [&](const std::vector<std::string> &pieces) { | |
| CHECK_OK(sp.Decode(pieces, &spt)); | |
| }; | |
| } else { | |
| LOG(FATAL) << "Unknown output format: " | |
| << absl::GetFlag(FLAGS_output_format); | |
| } | |
| } else if (absl::GetFlag(FLAGS_input_format) == "id") { | |
| if (absl::GetFlag(FLAGS_output_format) == "string") { | |
| process = [&](const std::vector<std::string> &pieces) { | |
| CHECK_OK(sp.Decode(ToIds(pieces), &detok)); | |
| output->WriteLine(detok); | |
| }; | |
| } else if (absl::GetFlag(FLAGS_output_format) == "proto") { | |
| process = [&](const std::vector<std::string> &pieces) { | |
| CHECK_OK(sp.Decode(ToIds(pieces), &spt)); | |
| }; | |
| } else { | |
| LOG(FATAL) << "Unknown output format: " | |
| << absl::GetFlag(FLAGS_output_format); | |
| } | |
| } else { | |
| LOG(FATAL) << "Unknown input format: " << absl::GetFlag(FLAGS_input_format); | |
| } | |
| for (const auto &filename : rest_args) { | |
| auto input = sentencepiece::filesystem::NewReadableFile(filename); | |
| CHECK_OK(input->status()); | |
| while (input->ReadLine(&line)) { | |
| const auto pieces = absl::StrSplit(line, " "); | |
| process(pieces); | |
| } | |
| } | |
| return 0; | |
| } | |
