competition update
This commit is contained in:
96
language_model/runtime/core/utils/blocking_queue.h
Normal file
96
language_model/runtime/core/utils/blocking_queue.h
Normal file
@@ -0,0 +1,96 @@
|
||||
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef UTILS_BLOCKING_QUEUE_H_
|
||||
#define UTILS_BLOCKING_QUEUE_H_
|
||||
|
||||
#include <condition_variable>
|
||||
#include <limits>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <utility>
|
||||
|
||||
#include "utils/utils.h"
|
||||
|
||||
namespace wenet {
|
||||
|
||||
template <typename T>
|
||||
class BlockingQueue {
|
||||
public:
|
||||
explicit BlockingQueue(size_t capacity = std::numeric_limits<int>::max())
|
||||
: capacity_(capacity) {}
|
||||
|
||||
void Push(const T& value) {
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
while (queue_.size() >= capacity_) {
|
||||
not_full_condition_.wait(lock);
|
||||
}
|
||||
queue_.push(value);
|
||||
}
|
||||
not_empty_condition_.notify_one();
|
||||
}
|
||||
|
||||
void Push(T&& value) {
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
while (queue_.size() >= capacity_) {
|
||||
not_full_condition_.wait(lock);
|
||||
}
|
||||
queue_.push(std::move(value));
|
||||
}
|
||||
not_empty_condition_.notify_one();
|
||||
}
|
||||
|
||||
T Pop() {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
while (queue_.empty()) {
|
||||
not_empty_condition_.wait(lock);
|
||||
}
|
||||
T t(std::move(queue_.front()));
|
||||
queue_.pop();
|
||||
not_full_condition_.notify_one();
|
||||
return t;
|
||||
}
|
||||
|
||||
bool Empty() const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return queue_.empty();
|
||||
}
|
||||
|
||||
size_t Size() const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return queue_.size();
|
||||
}
|
||||
|
||||
void Clear() {
|
||||
while (!Empty()) {
|
||||
Pop();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
size_t capacity_;
|
||||
mutable std::mutex mutex_;
|
||||
std::condition_variable not_full_condition_;
|
||||
std::condition_variable not_empty_condition_;
|
||||
std::queue<T> queue_;
|
||||
|
||||
public:
|
||||
WENET_DISALLOW_COPY_AND_ASSIGN(BlockingQueue);
|
||||
};
|
||||
|
||||
} // namespace wenet
|
||||
|
||||
#endif // UTILS_BLOCKING_QUEUE_H_
|
||||
23
language_model/runtime/core/utils/flags.h
Normal file
23
language_model/runtime/core/utils/flags.h
Normal file
@@ -0,0 +1,23 @@
|
||||
// Copyright (c) 2021 Mobvoi Inc (Binbin Zhang)
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef UTILS_FLAGS_H_
|
||||
#define UTILS_FLAGS_H_
|
||||
|
||||
// Because openfst is a dynamic library compiled with gflags/glog, we must use
|
||||
// the gflags/glog from openfst to avoid them linked both statically and
|
||||
// dynamically into the executable.
|
||||
#include "fst/flags.h"
|
||||
|
||||
#endif // UTILS_FLAGS_H_
|
||||
23
language_model/runtime/core/utils/log.h
Normal file
23
language_model/runtime/core/utils/log.h
Normal file
@@ -0,0 +1,23 @@
|
||||
// Copyright (c) 2021 Mobvoi Inc (Binbin Zhang)
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef UTILS_LOG_H_
|
||||
#define UTILS_LOG_H_
|
||||
|
||||
// Because openfst is a dynamic library compiled with gflags/glog, we must use
|
||||
// the gflags/glog from openfst to avoid them linked both statically and
|
||||
// dynamically into the executable.
|
||||
#include "fst/log.h"
|
||||
|
||||
#endif // UTILS_LOG_H_
|
||||
191
language_model/runtime/core/utils/string.cc
Normal file
191
language_model/runtime/core/utils/string.cc
Normal file
@@ -0,0 +1,191 @@
|
||||
// Copyright (c) 2021 Mobvoi Inc (Binbin Zhang)
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "utils/string.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "utils/log.h"
|
||||
#include "utils/utils.h"
|
||||
|
||||
namespace wenet {
|
||||
|
||||
void SplitString(const std::string& str, std::vector<std::string>* strs) {
|
||||
SplitStringToVector(str, " \t", true, strs);
|
||||
}
|
||||
|
||||
void SplitStringToVector(const std::string& full, const char* delim,
|
||||
bool omit_empty_strings,
|
||||
std::vector<std::string>* out) {
|
||||
size_t start = 0, found = 0, end = full.size();
|
||||
out->clear();
|
||||
while (found != std::string::npos) {
|
||||
found = full.find_first_of(delim, start);
|
||||
// start != end condition is for when the delimiter is at the end
|
||||
if (!omit_empty_strings || (found != start && start != end))
|
||||
out->push_back(full.substr(start, found - start));
|
||||
start = found + 1;
|
||||
}
|
||||
}
|
||||
|
||||
std::string UTF8CodeToUTF8String(int code) {
|
||||
std::ostringstream ostr;
|
||||
if (code < 0) {
|
||||
LOG(ERROR) << "LabelsToUTF8String: Invalid character found: " << code;
|
||||
return ostr.str();
|
||||
} else if (code < 0x80) {
|
||||
ostr << static_cast<char>(code);
|
||||
} else if (code < 0x800) {
|
||||
ostr << static_cast<char>((code >> 6) | 0xc0);
|
||||
ostr << static_cast<char>((code & 0x3f) | 0x80);
|
||||
} else if (code < 0x10000) {
|
||||
ostr << static_cast<char>((code >> 12) | 0xe0);
|
||||
ostr << static_cast<char>(((code >> 6) & 0x3f) | 0x80);
|
||||
ostr << static_cast<char>((code & 0x3f) | 0x80);
|
||||
} else if (code < 0x200000) {
|
||||
ostr << static_cast<char>((code >> 18) | 0xf0);
|
||||
ostr << static_cast<char>(((code >> 12) & 0x3f) | 0x80);
|
||||
ostr << static_cast<char>(((code >> 6) & 0x3f) | 0x80);
|
||||
ostr << static_cast<char>((code & 0x3f) | 0x80);
|
||||
} else if (code < 0x4000000) {
|
||||
ostr << static_cast<char>((code >> 24) | 0xf8);
|
||||
ostr << static_cast<char>(((code >> 18) & 0x3f) | 0x80);
|
||||
ostr << static_cast<char>(((code >> 12) & 0x3f) | 0x80);
|
||||
ostr << static_cast<char>(((code >> 6) & 0x3f) | 0x80);
|
||||
ostr << static_cast<char>((code & 0x3f) | 0x80);
|
||||
} else {
|
||||
ostr << static_cast<char>((code >> 30) | 0xfc);
|
||||
ostr << static_cast<char>(((code >> 24) & 0x3f) | 0x80);
|
||||
ostr << static_cast<char>(((code >> 18) & 0x3f) | 0x80);
|
||||
ostr << static_cast<char>(((code >> 12) & 0x3f) | 0x80);
|
||||
ostr << static_cast<char>(((code >> 6) & 0x3f) | 0x80);
|
||||
ostr << static_cast<char>((code & 0x3f) | 0x80);
|
||||
}
|
||||
return ostr.str();
|
||||
}
|
||||
|
||||
// Split utf8 string into characters.
|
||||
bool SplitUTF8String(const std::string& str,
|
||||
std::vector<std::string>* characters) {
|
||||
const char* data = str.data();
|
||||
const size_t length = str.size();
|
||||
for (size_t i = 0; i < length; /* no update */) {
|
||||
int c = data[i++] & 0xff;
|
||||
if ((c & 0x80) == 0) {
|
||||
characters->push_back(UTF8CodeToUTF8String(c));
|
||||
} else {
|
||||
if ((c & 0xc0) == 0x80) {
|
||||
LOG(ERROR) << "UTF8StringToLabels: continuation byte as lead byte";
|
||||
return false;
|
||||
}
|
||||
int count =
|
||||
(c >= 0xc0) + (c >= 0xe0) + (c >= 0xf0) + (c >= 0xf8) + (c >= 0xfc);
|
||||
int code = c & ((1 << (6 - count)) - 1);
|
||||
while (count != 0) {
|
||||
if (i == length) {
|
||||
LOG(ERROR) << "UTF8StringToLabels: truncated utf-8 byte sequence";
|
||||
return false;
|
||||
}
|
||||
char cb = data[i++];
|
||||
if ((cb & 0xc0) != 0x80) {
|
||||
LOG(ERROR) << "UTF8StringToLabels: missing/invalid continuation byte";
|
||||
return false;
|
||||
}
|
||||
code = (code << 6) | (cb & 0x3f);
|
||||
count--;
|
||||
}
|
||||
if (code < 0) {
|
||||
// This should not be able to happen.
|
||||
LOG(ERROR) << "UTF8StringToLabels: Invalid character found: " << c;
|
||||
return false;
|
||||
}
|
||||
characters->push_back(UTF8CodeToUTF8String(code));
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string ProcessBlank(const std::string& str) {
|
||||
std::string result;
|
||||
if (!str.empty()) {
|
||||
std::vector<std::string> characters;
|
||||
if (SplitUTF8String(str, &characters)) {
|
||||
for (std::string& character : characters) {
|
||||
if (character != kSpaceSymbol) {
|
||||
result.append(character);
|
||||
} else {
|
||||
// Ignore consecutive space or located in head
|
||||
if (!result.empty() && result.back() != ' ') {
|
||||
result.push_back(' ');
|
||||
}
|
||||
}
|
||||
}
|
||||
// Ignore tailing space
|
||||
if (!result.empty() && result.back() == ' ') {
|
||||
result.pop_back();
|
||||
}
|
||||
for (size_t i = 0; i < result.size(); ++i) {
|
||||
result[i] = tolower(result[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void SplitEachChar(const std::string& word, std::vector<std::string>* chars) {
|
||||
chars->clear();
|
||||
size_t i = 0;
|
||||
while (i < word.length()) {
|
||||
assert((word[i] & 0xF8) <= 0xF0);
|
||||
int bytes_ = 1;
|
||||
if ((word[i] & 0x80) == 0x00) {
|
||||
// The first 128 characters (US-ASCII) in UTF-8 format only need one byte.
|
||||
bytes_ = 1;
|
||||
} else if ((word[i] & 0xE0) == 0xC0) {
|
||||
// The next 1,920 characters need two bytes to encode,
|
||||
// which covers the remainder of almost all Latin-script alphabets.
|
||||
bytes_ = 2;
|
||||
} else if ((word[i] & 0xF0) == 0xE0) {
|
||||
// Three bytes are needed for characters in the rest of
|
||||
// the Basic Multilingual Plane, which contains virtually all characters
|
||||
// in common use, including most Chinese, Japanese and Korean characters.
|
||||
bytes_ = 3;
|
||||
} else if ((word[i] & 0xF8) == 0xF0) {
|
||||
// Four bytes are needed for characters in the other planes of Unicode,
|
||||
// which include less common CJK characters, various historic scripts,
|
||||
// mathematical symbols, and emoji (pictographic symbols).
|
||||
bytes_ = 4;
|
||||
}
|
||||
chars->push_back(word.substr(i, bytes_));
|
||||
i += bytes_;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
bool CheckEnglishWord(const std::string& word) {
|
||||
std::vector<std::string> chars;
|
||||
SplitEachChar(word, &chars);
|
||||
for (size_t k = 0; k < chars.size(); k++) {
|
||||
// all english characters should be encoded in one byte
|
||||
if (chars[k].size() > 1) return false;
|
||||
// english words may contain apostrophe, i.e., "He's"
|
||||
if (chars[k][0] == '\'') continue;
|
||||
if (!isalpha(chars[k][0])) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace wenet
|
||||
45
language_model/runtime/core/utils/string.h
Normal file
45
language_model/runtime/core/utils/string.h
Normal file
@@ -0,0 +1,45 @@
|
||||
// Copyright (c) 2021 Mobvoi Inc (Binbin Zhang)
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef UTILS_STRING_H_
|
||||
#define UTILS_STRING_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace wenet {
|
||||
|
||||
void SplitString(const std::string& str, std::vector<std::string>* strs);
|
||||
|
||||
void SplitStringToVector(const std::string& full, const char* delim,
|
||||
bool omit_empty_strings,
|
||||
std::vector<std::string>* out);
|
||||
|
||||
bool SplitUTF8String(const std::string& str,
|
||||
std::vector<std::string>* characters);
|
||||
|
||||
// Remove head,tail and consecutive space.
|
||||
std::string ProcessBlank(const std::string& str);
|
||||
|
||||
// NOTE(Xingchen Song): we add this function to make it possible to
|
||||
// support multilingual recipe in the future, in which characters of
|
||||
// different languages are all encoded in UTF-8 format.
|
||||
// UTF-8 REF: https://en.wikipedia.org/wiki/UTF-8#Encoding
|
||||
void SplitEachChar(const std::string& word, std::vector<std::string>* chars);
|
||||
|
||||
bool CheckEnglishWord(const std::string& word);
|
||||
|
||||
} // namespace wenet
|
||||
|
||||
#endif // UTILS_STRING_H_
|
||||
39
language_model/runtime/core/utils/timer.h
Normal file
39
language_model/runtime/core/utils/timer.h
Normal file
@@ -0,0 +1,39 @@
|
||||
// Copyright (c) 2021 Mobvoi Inc (Binbin Zhang)
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef UTILS_TIMER_H_
|
||||
#define UTILS_TIMER_H_
|
||||
|
||||
#include <chrono>
|
||||
|
||||
namespace wenet {
|
||||
|
||||
class Timer {
|
||||
public:
|
||||
Timer() : time_start_(std::chrono::steady_clock::now()) {}
|
||||
void Reset() { time_start_ = std::chrono::steady_clock::now(); }
|
||||
// return int in milliseconds
|
||||
int Elapsed() const {
|
||||
auto time_now = std::chrono::steady_clock::now();
|
||||
return std::chrono::duration_cast<std::chrono::milliseconds>(time_now -
|
||||
time_start_)
|
||||
.count();
|
||||
}
|
||||
|
||||
private:
|
||||
std::chrono::time_point<std::chrono::steady_clock> time_start_;
|
||||
};
|
||||
} // namespace wenet
|
||||
|
||||
#endif // UTILS_TIMER_H_
|
||||
32
language_model/runtime/core/utils/utils.cc
Normal file
32
language_model/runtime/core/utils/utils.cc
Normal file
@@ -0,0 +1,32 @@
|
||||
// Copyright (c) 2021 Mobvoi Inc (Zhendong Peng)
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "utils/utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
|
||||
#include "utils/log.h"
|
||||
|
||||
namespace wenet {
|
||||
|
||||
float LogAdd(float x, float y) {
|
||||
static float num_min = -std::numeric_limits<float>::max();
|
||||
if (x <= num_min) return y;
|
||||
if (y <= num_min) return x;
|
||||
float xmax = std::max(x, y);
|
||||
return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax;
|
||||
}
|
||||
|
||||
} // namespace wenet
|
||||
34
language_model/runtime/core/utils/utils.h
Normal file
34
language_model/runtime/core/utils/utils.h
Normal file
@@ -0,0 +1,34 @@
|
||||
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef UTILS_UTILS_H_
|
||||
#define UTILS_UTILS_H_
|
||||
|
||||
#include <limits>
|
||||
|
||||
namespace wenet {
|
||||
|
||||
#define WENET_DISALLOW_COPY_AND_ASSIGN(Type) \
|
||||
Type(const Type &) = delete; \
|
||||
Type &operator=(const Type &) = delete;
|
||||
|
||||
const float kFloatMax = std::numeric_limits<float>::max();
|
||||
const char kSpaceSymbol[] = "\xe2\x96\x81";
|
||||
|
||||
// Return the sum of two probabilities in log scale
|
||||
float LogAdd(float x, float y);
|
||||
|
||||
} // namespace wenet
|
||||
|
||||
#endif // UTILS_UTILS_H_
|
||||
Reference in New Issue
Block a user