Commit ee180e5e by Jan Wijffels

All

parent c345122d
^.*\.Rproj$
^\.Rproj\.user$
Package: textspace
Type: Package
Title: R Wrapper Around the 'StarSpace' Entity Embedding Library
Version: 0.1.0
Author: Who wrote it
Maintainer: The package maintainer <yourself@somewhere.net>
Description: Provides an interface to the 'StarSpace' C++ library from Facebook Research
    for training and querying entity embeddings (documents, ngrams, nearest neighbors).
License: What license is it under?
Encoding: UTF-8
LazyData: true
Imports: Rcpp (>= 0.12.14)
LinkingTo: Rcpp, BH
RoxygenNote: 6.0.1
SystemRequirements: C++11
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393
# Auto-generated Rcpp binding (see the compileAttributes() header above):
# invokes the compiled C++ routine '_textspace_rcpp_hello' registered in
# the 'textspace' shared library and returns its result.
rcpp_hello <- function() {
.Call('_textspace_rcpp_hello', PACKAGE = 'textspace')
}
# Hello, world!
#
# This is an example function named 'hello'
# which prints 'Hello, world!'.
#
# You can learn more about package authoring with RStudio at:
#
# http://r-pkgs.had.co.nz/
#
# Some useful keyboard shortcuts for package authoring:
#
# Build and Reload Package: 'Ctrl + Shift + B'
# Check Package: 'Ctrl + Shift + E'
# Test Package: 'Ctrl + Shift + T'
# Example function from the RStudio package skeleton: prints (and invisibly
# returns, via print()) the string "Hello, world!".
hello <- function() {
print("Hello, world!")
}
// Generated by using Rcpp::compileAttributes() -> do not edit by hand
// Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393
#include <Rcpp.h>
using namespace Rcpp;
// rcpp_hello
List rcpp_hello();
// Auto-generated Rcpp wrapper: exposes rcpp_hello() to R as a .Call routine,
// converting its List result to a SEXP inside the BEGIN/END_RCPP guards.
RcppExport SEXP _textspace_rcpp_hello() {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
rcpp_result_gen = Rcpp::wrap(rcpp_hello());
return rcpp_result_gen;
END_RCPP
}
// Registration table for .Call routines: {name, function pointer, arg count},
// terminated by a NULL sentinel row.
static const R_CallMethodDef CallEntries[] = {
{"_textspace_rcpp_hello", (DL_FUNC) &_textspace_rcpp_hello, 0},
{NULL, NULL, 0}
};
// Package initializer called by R when the shared library is loaded:
// registers the .Call routines above and disables dynamic symbol lookup.
RcppExport void R_init_textspace(DllInfo *dll) {
R_registerRoutines(dll, NULL, CallEntries, NULL, NULL);
R_useDynamicSymbols(dll, FALSE);
}
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
#include "../starspace.h"
#include <iostream>
#include <boost/algorithm/string/predicate.hpp>
using namespace std;
using namespace starspace;
// Read sentences / documents from 'fin' one line at a time and print each
// line followed by its embedding vector. Stops at EOF or at an empty line.
void embedDoc(StarSpace& sp, istream& fin) {
  for (string line; getline(fin, line) && !line.empty();) {
    cout << line << endl;
    auto docVec = sp.getDocVector(line);
    docVec.forEachCell([](Real cell) { cout << cell << ' '; });
    cout << endl;
  }
}
// Loads a StarSpace model and prints the embedding vector of each input
// document, read either from a file (optional second argument) or from stdin.
int main(int argc, char** argv) {
  shared_ptr<Args> args = make_shared<Args>();
  if (argc < 2) {
    cerr << "usage: " << argv[0] << " <model> [filename]\n";
    // Fixed: the two literals used to concatenate as "andoutput" and the
    // message had no trailing newline.
    cerr << "if filename is specified, it reads each line from the file and "
         << "output corresponding vectors\n";
    return 1;
  }
  std::string model(argv[1]);
  args->model = model;
  StarSpace sp(args);
  sp.initFromSavedModel(args->model);
  // set useWeight by default.
  // use 1.0 for default weight if weight is not found
  args->useWeight = true;
  if (argc > 2) {
    std::string filename(argv[2]);
    ifstream fin(filename);
    if (!fin.is_open()) {
      std::cerr << "file cannot be opened for loading!" << std::endl;
      exit(EXIT_FAILURE);
    }
    embedDoc(sp, fin);
    fin.close();
  } else {
    cout << "Input your sentence / document now:\n";
    embedDoc(sp, cin);
  }
  return 0;
}
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
#include "../starspace.h"
#include <iostream>
#include <boost/algorithm/string/predicate.hpp>
using namespace std;
using namespace starspace;
int main(int argc, char** argv) {
shared_ptr<Args> args = make_shared<Args>();
if (argc < 2) {
cerr << "usage: " << argv[0] << " <model> [k]\n";
return 1;
}
std::string model(argv[1]);
args->model = model;
StarSpace sp(args);
sp.initFromSavedModel(args->model);
if (args->ngrams == 1) {
std::cerr << "Error: your provided model does not use ngram.\n";
exit(EXIT_FAILURE);
}
string input;
while (getline(cin, input)) {
auto vec = sp.getNgramVector(input);
cout << input;
for (auto v : vec) { cout << "\t" << v; }
cout << endl;
}
return 0;
}
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
#include "../starspace.h"
#include <iostream>
#include <boost/algorithm/string/predicate.hpp>
using namespace std;
using namespace starspace;
int main(int argc, char** argv) {
shared_ptr<Args> args = make_shared<Args>();
if (argc < 2) {
cerr << "usage: " << argv[0] << " <model> [k]\n";
return 1;
}
std::string model(argv[1]);
args->model = model;
int k = 5;
if (argc > 2) {
k = atoi(argv[2]);
}
StarSpace sp(args);
if (boost::algorithm::ends_with(args->model, ".tsv")) {
sp.initFromTsv(args->model);
} else {
sp.initFromSavedModel(args->model);
}
cout << "------Loaded model args:\n";
args->printArgs();
for(;;) {
string input;
cout << "Enter some text: ";
if (!getline(cin, input) || input.size() == 0) break;
sp.nearestNeighbor(input, k);
}
return 0;
}
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
#include "../starspace.h"
#include <iostream>
#include <boost/algorithm/string/predicate.hpp>
using namespace std;
using namespace starspace;
// Interactive prediction: loads a model (and an optional basedoc file of
// candidate labels), then for each input line prints the top-K predictions
// with their scores.
int main(int argc, char** argv) {
  shared_ptr<Args> args = make_shared<Args>();
  if (argc < 3) {
    cerr << "usage: " << argv[0] << " <model> k [basedoc]\n";
    return 1;
  }
  std::string model(argv[1]);
  args->K = atoi(argv[2]);
  args->model = model;
  if (argc > 3) {
    args->fileFormat = "labelDoc";
    args->basedoc = argv[3];
  }
  StarSpace sp(args);
  if (boost::algorithm::ends_with(args->model, ".tsv")) {
    sp.initFromTsv(args->model);
  } else {
    sp.initFromSavedModel(args->model);
    cout << "------Loaded model args:\n";
    args->printArgs();
  }
  // Set dropout probability to 0 in test case.
  sp.args_->dropoutLHS = 0.0;
  sp.args_->dropoutRHS = 0.0;
  // Load basedocs which are set of possible things to predict.
  sp.loadBaseDocs();
  for (;;) {
    string input;
    cout << "Enter some text: ";
    if (!getline(cin, input) || input.size() == 0) break;
    // Do the prediction
    vector<Base> query_vec;
    sp.parseDoc(input, query_vec, " ");
    vector<Predictions> predictions;
    sp.predictOne(query_vec, predictions);
    // size_t index avoids the signed/unsigned comparison of the original
    for (size_t i = 0; i < predictions.size(); i++) {
      cout << i << "[" << predictions[i].first << "]: ";
      sp.printDoc(cout, sp.baseDocs_[predictions[i].second]);
    }
    cout << "\n";
  }
  return 0;
}
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
#include "data.h"
#include "utils/utils.h"
#include <string>
#include <vector>
#include <fstream>
#include <assert.h>
using namespace std;
namespace starspace {
// Initialize an empty data handler; examples are added later via
// loadFromFile() or addExample().
InternDataHandler::InternDataHandler(shared_ptr<Args> args) : args_(args) {
  size_ = 0;
  idx_ = -1;
  examples_.clear();
}
// Report a fatal error when a data file yields no usable example and
// terminate the process.
void InternDataHandler::errorOnZeroExample(const string& fileName) {
  std::cerr << "ERROR: File '" << fileName
            << "' does not contain any valid example.\n";
  std::cerr << "Please check: is the file empty? "
            << "Do the examples contain proper feature and label according to the trainMode? "
            << "If your examples are unlabeled, try to set trainMode=5.\n";
  exit(EXIT_FAILURE);
}
// Parse every line of 'fileName' in parallel (args_->thread workers) and
// append the resulting examples to examples_. Exits the process if the
// file cannot be opened or produces no valid example.
void InternDataHandler::loadFromFile(
const string& fileName,
shared_ptr<DataParser> parser) {
// First open is only a readability check; foreach_line re-opens the file.
ifstream fin(fileName);
if (!fin.is_open()) {
std::cerr << fileName << " cannot be opened for loading!" << std::endl;
exit(EXIT_FAILURE);
}
fin.close();
cout << "Loading data from file : " << fileName << endl;
// One corpus per worker thread, so parsing needs no synchronization.
vector<Corpus> corpora(args_->thread);
foreach_line(
fileName,
[&](std::string& line) {
auto& corpus = corpora[getThreadID()];
ParseResults example;
if (parser->parse(line, example)) {
corpus.push_back(example);
}
},
args_->thread
);
// Glue corpora together.
auto totalSize = std::accumulate(corpora.begin(), corpora.end(), size_t(0),
[](size_t l, Corpus& r) { return l + r.size(); });
size_t destCursor = examples_.size();
examples_.resize(totalSize + examples_.size());
for (const auto &subcorp: corpora) {
std::copy(subcorp.begin(), subcorp.end(), examples_.begin() + destCursor);
destCursor += subcorp.size();
}
cout << "Total number of examples loaded : " << examples_.size() << endl;
size_ = examples_.size();
if (size_ == 0) {
errorOnZeroExample(fileName);
}
}
// Convert an example for training/testing if needed.
// In the case of trainMode=1, a random label from r.h.s will be selected
// as label, and the rest of labels from r.h.s. will be input features
//
// Summary of the modes implemented below (labels are rhs tokens):
//   0: lhs copied as-is; one random rhs token becomes the label
//   1: one random rhs token is the label; the others join the lhs
//   2: one random rhs token becomes the lhs; the others form the rhs
//   3: two distinct random rhs tokens: one lhs, one rhs
//   4: first rhs token -> lhs, second rhs token -> rhs
void InternDataHandler::convert(
const ParseResults& example,
ParseResults& rslt) const {
rslt.weight = example.weight;
rslt.LHSTokens.clear();
rslt.RHSTokens.clear();
rslt.LHSTokens.insert(rslt.LHSTokens.end(),
example.LHSTokens.begin(), example.LHSTokens.end());
if (args_->trainMode == 0) {
// lhs is the same, pick one random label as rhs
assert(example.LHSTokens.size() > 0);
assert(example.RHSTokens.size() > 0);
auto idx = rand() % example.RHSTokens.size();
rslt.RHSTokens.push_back(example.RHSTokens[idx]);
} else {
// modes 1-4 all need at least two rhs tokens to split between sides
assert(example.RHSTokens.size() > 1);
if (args_->trainMode == 1) {
// pick one random label as rhs and the rest is lhs
auto idx = rand() % example.RHSTokens.size();
for (int i = 0; i < example.RHSTokens.size(); i++) {
auto tok = example.RHSTokens[i];
if (i == idx) {
rslt.RHSTokens.push_back(tok);
} else {
rslt.LHSTokens.push_back(tok);
}
}
} else
if (args_->trainMode == 2) {
// pick one random label as lhs and the rest is rhs
auto idx = rand() % example.RHSTokens.size();
for (int i = 0; i < example.RHSTokens.size(); i++) {
auto tok = example.RHSTokens[i];
if (i == idx) {
rslt.LHSTokens.push_back(tok);
} else {
rslt.RHSTokens.push_back(tok);
}
}
} else
if (args_->trainMode == 3) {
// pick two random labels, one as lhs and the other as rhs
auto idx = rand() % example.RHSTokens.size();
int idx2;
do {
idx2 = rand() % example.RHSTokens.size();
} while (idx2 == idx);
rslt.LHSTokens.push_back(example.RHSTokens[idx]);
rslt.RHSTokens.push_back(example.RHSTokens[idx2]);
} else
if (args_->trainMode == 4) {
// the first one as lhs and the second one as rhs
rslt.LHSTokens.push_back(example.RHSTokens[0]);
rslt.RHSTokens.push_back(example.RHSTokens[1]);
}
}
}
// Build one word-level training example per word in 'doc': the word itself
// becomes the rhs label and the words inside the +/- args_->ws window
// around it form the lhs context. Each example carries args_->wordWeight.
void InternDataHandler::getWordExamples(
const vector<Base>& doc,
vector<ParseResults>& rslts) const {
rslts.clear();
for (int widx = 0; widx < doc.size(); widx++) {
ParseResults rslt;
rslt.LHSTokens.clear();
rslt.RHSTokens.clear();
rslt.RHSTokens.push_back(doc[widx]);
// window clamped to document boundaries; NOTE(review): the upper bound is
// widx + ws (exclusive), i.e. the right context is one word shorter than
// the left — confirm this asymmetry is intended.
for (int i = max(widx - args_->ws, 0);
i < min(size_t(widx + args_->ws), doc.size()); i++) {
if (i != widx) {
rslt.LHSTokens.push_back(doc[i]);
}
}
rslt.weight = args_->wordWeight;
rslts.emplace_back(rslt);
}
}
void InternDataHandler::getWordExamples(
int idx,
vector<ParseResults>& rslts) const {
assert(idx < size_);
const auto& example = examples_[idx];
getWordExamples(example.LHSTokens, rslts);
}
// Append one parsed example and keep the cached size counter in sync.
void InternDataHandler::addExample(const ParseResults& example) {
  examples_.push_back(example);
  ++size_;
}
// Fetch example 'idx', converted according to the current trainMode.
void InternDataHandler::getExampleById(int32_t idx, ParseResults& rslt) const {
  assert(idx < size_);
  convert(examples_[idx], rslt);
}
// Return the next example in storage order, wrapping around at the end.
void InternDataHandler::getNextExample(ParseResults& rslt) {
  assert(size_ > 0);
  idx_ += 1;
  if (idx_ >= size_) {
    // reached the end: wrap back to the first example
    idx_ -= size_;
  }
  convert(examples_[idx_], rslt);
}
// Return a uniformly random example, converted for training.
void InternDataHandler::getRandomExample(ParseResults& rslt) const {
  assert(size_ > 0);
  convert(examples_[rand() % size_], rslt);
}
void InternDataHandler::getKRandomExamples(int K, vector<ParseResults>& c) {
auto kSamples = min(K, size_);
for (int i = 0; i < kSamples; i++) {
ParseResults example;
getRandomExample(example);
c.push_back(example);
}
}
void InternDataHandler::getNextKExamples(int K, vector<ParseResults>& c) {
auto kSamples = min(K, size_);
for (int i = 0; i < kSamples; i++) {
idx_ = (idx_ + 1) % size_;
ParseResults example;
convert(examples_[idx_], example);
c.push_back(example);
}
}
// Randomly sample one example and randomly sample a label from this example
// The result is usually used as negative samples in training
// NOTE(review): assumes the sampled example has non-empty LHSTokens /
// RHSTokens; an empty list would make the modulo below undefined behavior —
// confirm the parsers never store such examples.
void InternDataHandler::getRandomRHS(vector<Base>& results, bool trainWord) const {
assert(size_ > 0);
results.clear();
auto& ex = examples_[rand() % size_];
if (args_->trainMode == 5 || trainWord) {
// word-level training: the negative sample is a random lhs token
int r = rand() % ex.LHSTokens.size();
results.push_back(ex.LHSTokens[r]);
} else {
int r = rand() % ex.RHSTokens.size();
if (args_->trainMode == 2) {
// mode 2 uses "all labels but one" as rhs, so drop one at random
for (int i = 0; i < ex.RHSTokens.size(); i++) {
if (i != r) {
results.push_back(ex.RHSTokens[i]);
}
}
} else {
results.push_back(ex.RHSTokens[r]);
}
}
}
// Dump every example to 'out' in a human-readable "id:weight" format,
// one lhs line and one rhs line per example (debugging aid).
void InternDataHandler::save(std::ostream& out) {
  out << "data size : " << size_ << endl;
  for (const auto& example : examples_) {
    out << "lhs : ";
    for (const auto& t : example.LHSTokens) {
      out << t.first << ':' << t.second << ' ';
    }
    out << endl;
    out << "rhs : ";
    for (const auto& t : example.RHSTokens) {
      out << t.first << ':' << t.second << ' ';
    }
    out << endl;
  }
}
} // namespace starspace
// Copyright 2004-, Facebook, Inc. All Rights Reserved.
/* This is the basic class of internal data handler.
* It loads data from file and stores it in internal format for easy access
* at training/testing time.
*
* It also provides random RHS sampling for negative sampling in training.
*/
#pragma once
#include "dict.h"
#include "parser.h"
#include <string>
#include <vector>
#include <fstream>
namespace starspace {
// Base class that holds parsed training examples in memory and serves them
// (sequentially or at random), converted for the configured trainMode.
class InternDataHandler {
public:
explicit InternDataHandler(std::shared_ptr<Args> args);
// Parse 'file' (possibly multi-threaded) and append its examples.
virtual void loadFromFile(const std::string& file,
std::shared_ptr<DataParser> parser);
// Rearrange an example's lhs/rhs according to args_->trainMode.
virtual void convert(const ParseResults& example, ParseResults& rslt) const;
// Sample a random label set, used for negative sampling.
virtual void getRandomRHS(std::vector<Base>& results, bool trainWord = false)
const;
// Write a human-readable dump of all examples.
virtual void save(std::ostream& out);
// Word-level (window-based) examples derived from one stored example.
virtual void getWordExamples(int idx, std::vector<ParseResults>& rslt) const;
void getWordExamples(
const std::vector<Base>& doc,
std::vector<ParseResults>& rslt) const;
void addExample(const ParseResults& example);
void getExampleById(int32_t idx, ParseResults& rslt) const;
void getNextExample(ParseResults& rslt);
void getRandomExample(ParseResults& rslt) const;
void getKRandomExamples(int K, std::vector<ParseResults>& c);
void getNextKExamples(int K, std::vector<ParseResults>& c);
size_t getSize() const { return size_; };
void errorOnZeroExample(const std::string& fileName);
protected:
static const int32_t MAX_VOCAB_SIZE = 10000000;
std::shared_ptr<Args> args_;
// All loaded examples, in file order.
std::vector<ParseResults> examples_;
// Cursor used by getNextExample() / getNextKExamples().
int32_t idx_ = -1;
// Number of examples; kept in sync with examples_.size().
int32_t size_ = 0;
};
}
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
#include "dict.h"
#include "parser.h"
#include <assert.h>
#include <algorithm>
#include <iterator>
#include <cmath>
#include <fstream>
#include <sstream>
using namespace std;
namespace starspace {
// End-of-sentence marker symbol (same convention as fastText).
const std::string Dictionary::EOS = "</s>";
// Hash multiplier constant; presumably used for ngram hashing elsewhere —
// no use is visible in this file, confirm against addNgrams().
const uint32_t Dictionary::HASH_C = 116049371;
// Start with an empty entry list and a hash table of MAX_VOCAB_SIZE slots,
// all marked empty (-1).
Dictionary::Dictionary(shared_ptr<Args> args) : args_(args),
hashToIndex_(MAX_VOCAB_SIZE, -1), size_(0), nwords_(0), nlabels_(0),
ntokens_(0)
{
entryList_.clear();
}
// FNV-1a-style string hash (hash trick borrowed from fastText).
uint32_t Dictionary::hash(const std::string& str) const {
  uint32_t h = 2166136261;
  for (char ch : str) {
    h ^= uint32_t(ch);
    h *= 16777619;
  }
  return h;
}
// Locate the hash-table slot for 'w' by linear probing: returns the first
// slot that is either empty (hashToIndex_ == -1, symbol absent) or whose
// entry holds 'w'.
int32_t Dictionary::find(const std::string& w) const {
int32_t h = hash(w) % MAX_VOCAB_SIZE;
while (hashToIndex_[h] != -1 && entryList_[hashToIndex_[h]].symbol != w) {
h = (h + 1) % MAX_VOCAB_SIZE;
}
return h;
}
// Map a symbol to its entry id, or -1 when it is not in the dictionary.
int32_t Dictionary::getId(const string& symbol) const {
  return hashToIndex_[find(symbol)];
}
// Return the symbol text stored at entry 'id'.
const std::string& Dictionary::getSymbol(int32_t id) const {
  assert(id >= 0 && id < size_);
  return entryList_[id].symbol;
}
// Return the symbol of the lid-th label; labels are stored after the
// nwords_ word entries in entryList_.
const std::string& Dictionary::getLabel(int32_t lid) const {
  assert(lid >= 0 && lid < nlabels_);
  return entryList_[lid + nwords_].symbol;
}
// Return whether entry 'id' is a word or a label.
entry_type Dictionary::getType(int32_t id) const {
  assert(id >= 0 && id < size_);
  return entryList_[id].type;
}
// A symbol is a label iff it starts with the configured label prefix.
entry_type Dictionary::getType(const string& w) const {
  if (w.find(args_->label) == 0) {
    return entry_type::label;
  }
  return entry_type::word;
}
// Record one occurrence of 'symbol': create a new entry the first time it
// is seen, otherwise just bump its count. ntokens_ counts every call.
void Dictionary::insert(const string& symbol) {
int32_t h = find(symbol);
ntokens_++;
if (hashToIndex_[h] == -1) {
// new symbol: append an entry and point the (empty) hash slot at it
entry e;
e.symbol = symbol;
e.count = 1;
e.type = getType(symbol);
entryList_.push_back(e);
hashToIndex_[h] = size_++;
} else {
entryList_[hashToIndex_[h]].count++;
}
}
// Serialize the dictionary in binary form: the four counters first, then
// for each entry its NUL-terminated symbol, count and type. Mirrors load().
void Dictionary::save(std::ostream& out) const {
out.write((char*) &size_, sizeof(int32_t));
out.write((char*) &nwords_, sizeof(int32_t));
out.write((char*) &nlabels_, sizeof(int32_t));
out.write((char*) &ntokens_, sizeof(int64_t));
for (int32_t i = 0; i < size_; i++) {
entry e = entryList_[i];
out.write(e.symbol.data(), e.symbol.size() * sizeof(char));
// NUL terminator separates the symbol from the binary fields
out.put(0);
out.write((char*) &(e.count), sizeof(int64_t));
out.write((char*) &(e.type), sizeof(entry_type));
}
}
// Load a dictionary previously written by save(): the four counters, then
// per entry a NUL-terminated symbol followed by its binary count and type.
// Rebuilds the hash table as entries are read.
void Dictionary::load(std::istream& in) {
  entryList_.clear();
  std::fill(hashToIndex_.begin(), hashToIndex_.end(), -1);
  in.read((char*) &size_, sizeof(int32_t));
  in.read((char*) &nwords_, sizeof(int32_t));
  in.read((char*) &nlabels_, sizeof(int32_t));
  in.read((char*) &ntokens_, sizeof(int64_t));
  for (int32_t i = 0; i < size_; i++) {
    entry e;
    // Read the NUL-terminated symbol. Use int (not char) so EOF is
    // detectable: the original char-based loop could never compare equal
    // to EOF and spun forever on truncated input.
    int c;
    while ((c = in.get()) != 0 && c != std::istream::traits_type::eof()) {
      e.symbol.push_back(static_cast<char>(c));
    }
    in.read((char*) &e.count, sizeof(int64_t));
    in.read((char*) &e.type, sizeof(entry_type));
    entryList_.push_back(e);
    hashToIndex_[find(e.symbol)] = i;
  }
}
/* Build dictionary from file.
* In dictionary building process, if the current dictionary is at 75% capacity,
* it automatically increases the threshold for both word and label.
* At the end the -minCount and -minCountLabel from arguments will be applied
* as thresholds.
*/
void Dictionary::readFromFile(
const std::string& file,
shared_ptr<DataParser> parser) {
cout << "Build dict from input file : " << file << endl;
ifstream fin(file);
if (!fin.is_open()) {
cerr << "Input file cannot be opened!" << endl;
exit(EXIT_FAILURE);
}
int64_t minThreshold = 1;
size_t lines_read = 0;
std::string line;
while (getline(fin, line)) {
vector<string> tokens;
parser->parseForDict(line, tokens);
lines_read++;
for (auto token : tokens) {
insert(token);
// progress report once per million tokens
if ((ntokens_ % 1000000 == 0) && args_->verbose) {
std::cerr << "\rRead " << ntokens_ / 1000000 << "M words" << std::flush;
}
// at 75% capacity, prune rare symbols with a steadily growing threshold
if (size_ > 0.75 * MAX_VOCAB_SIZE) {
minThreshold++;
threshold(minThreshold, minThreshold);
}
}
}
fin.close();
// final pruning with the user-supplied minimum counts
threshold(args_->minCount, args_->minCountLabel);
std::cerr << "\rRead " << ntokens_ / 1000000 << "M words" << std::endl;
std::cerr << "Number of words in dictionary: " << nwords_ << std::endl;
std::cerr << "Number of labels in dictionary: " << nlabels_ << std::endl;
if (lines_read == 0) {
std::cerr << "ERROR: Empty file." << std::endl;
exit(EXIT_FAILURE);
}
if (size_ == 0) {
std::cerr << "Empty vocabulary. Try a smaller -minCount value."
<< std::endl;
exit(EXIT_FAILURE);
}
}
// Sort the dictionary by [word, label] order and by number of occurance.
// Removes word / label that does not pass respective threshold.
void Dictionary::threshold(int64_t t, int64_t tl) {
// words before labels; within a type, most frequent first.
// NOTE(review): std::sort is not stable, so equal-count entries keep no
// particular relative order.
sort(entryList_.begin(), entryList_.end(), [](const entry& e1, const entry& e2) {
if (e1.type != e2.type) return e1.type < e2.type;
return e1.count > e2.count;
});
// drop words below 't' and labels below 'tl'
entryList_.erase(remove_if(entryList_.begin(), entryList_.end(), [&](const entry& e) {
return (e.type == entry_type::word && e.count < t) ||
(e.type == entry_type::label && e.count < tl);
}), entryList_.end());
entryList_.shrink_to_fit();
// rebuild the hash table and the size/word/label counters
computeCounts();
}
// Rebuild the hash table and the size / word / label counters after the
// entry list has been reordered or pruned.
void Dictionary::computeCounts() {
  size_ = 0;
  nwords_ = 0;
  nlabels_ = 0;
  std::fill(hashToIndex_.begin(), hashToIndex_.end(), -1);
  for (const auto& e : entryList_) {
    hashToIndex_[find(e.symbol)] = size_++;
    if (e.type == entry_type::word) {
      nwords_++;
    } else if (e.type == entry_type::label) {
      nlabels_++;
    }
  }
}
// Given a model saved in .tsv format, build the dictionary from model.
void Dictionary::loadDictFromModel(const string& modelfile) {
cout << "Loading dict from model file : " << modelfile << endl;
ifstream fin(modelfile);
string line;
while (getline(fin, line)) {
string symbol;
stringstream ss(line);
ss >> symbol;
insert(symbol);
}
fin.close();
computeCounts();
std::cout << "Number of words in dictionary: " << nwords_ << std::endl;
std::cout << "Number of labels in dictionary: " << nlabels_ << std::endl;
}
} // namespace
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
/**
* The implementation of dictionary here is very similar to the dictionary used
* in fastText (https://github.com/facebookresearch/fastText).
*/
#pragma once
#include "utils/args.h"
#include <vector>
#include <string>
#include <unordered_map>
#include <iostream>
#include <random>
#include <memory>
namespace starspace {
class DataParser;
// A dictionary entry is either a word or a label (labels are recognized by
// their prefix, see Dictionary::getType).
enum class entry_type : int8_t {word=0, label=1};
// One dictionary slot: the token text, its occurrence count and its kind.
struct entry {
std::string symbol;
int64_t count;
entry_type type;
};
// Token dictionary with open-addressed hash lookup, closely modeled on the
// dictionary used in fastText. Word entries are stored before label entries.
class Dictionary {
public:
static const std::string EOS;
static const uint32_t HASH_C;
explicit Dictionary(std::shared_ptr<Args>);
int32_t size() const { return size_; };
int32_t nwords() const { return nwords_; };
int32_t nlabels() const { return nlabels_; };
// Fixed: ntokens_ is int64_t; returning int32_t truncated the token count
// on corpora larger than ~2.1 billion tokens. Widening the return type is
// source-compatible for existing callers.
int64_t ntokens() const { return ntokens_; };
int32_t getId(const std::string&) const;
entry_type getType(int32_t) const;
entry_type getType(const std::string&) const;
const std::string& getSymbol(int32_t) const;
const std::string& getLabel(int32_t) const;
uint32_t hash(const std::string& str) const;
void insert(const std::string&);
void load(std::istream&);
void save(std::ostream&) const;
void readFromFile(const std::string&, std::shared_ptr<DataParser>);
bool readWord(std::istream&, std::string&) const;
void threshold(int64_t, int64_t);
void computeCounts();
void loadDictFromModel(const std::string& model);
private:
static const int32_t MAX_VOCAB_SIZE = 30000000;
// Linear-probing slot lookup for a symbol.
int32_t find(const std::string&) const;
void addNgrams(
std::vector<int32_t>& line,
const std::vector<int32_t>& hashes,
int32_t n) const;
std::shared_ptr<Args> args_;
std::vector<entry> entryList_;
// Maps hash slot -> index into entryList_; -1 marks an empty slot.
std::vector<int32_t> hashToIndex_;
int32_t size_;
int32_t nwords_;
int32_t nlabels_;
int64_t ntokens_;
};
}
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
#include "doc_data.h"
#include "utils/utils.h"
#include <string>
#include <vector>
#include <fstream>
#include <assert.h>
using namespace std;
namespace starspace {
// A LayerDataHandler is an InternDataHandler with no extra state; all
// initialization happens in the base class.
LayerDataHandler::LayerDataHandler(shared_ptr<Args> args)
    : InternDataHandler(args) {}
// Parse every line of 'fileName' in parallel and append the examples.
// NOTE(review): duplicates InternDataHandler::loadFromFile almost verbatim —
// consider delegating to the base implementation.
void LayerDataHandler::loadFromFile(
const string& fileName,
shared_ptr<DataParser> parser) {
// readability check only; foreach_line re-opens the file itself
ifstream fin(fileName);
if (!fin.is_open()) {
std::cerr << fileName << " cannot be opened for loading!" << std::endl;
exit(EXIT_FAILURE);
}
fin.close();
cout << "Loading data from file : " << fileName << endl;
// one corpus per worker thread, so parsing needs no synchronization
vector<Corpus> corpora(args_->thread);
foreach_line(
fileName,
[&](std::string& line) {
auto& corpus = corpora[getThreadID()];
ParseResults example;
if (parser->parse(line, example)) {
corpus.push_back(example);
}
},
args_->thread
);
// Glue corpora together.
auto totalSize = std::accumulate(corpora.begin(), corpora.end(), size_t(0),
[](size_t l, Corpus& r) { return l + r.size(); });
size_t destCursor = examples_.size();
examples_.resize(totalSize + examples_.size());
for (const auto &subcorp: corpora) {
std::copy(subcorp.begin(), subcorp.end(), examples_.begin() + destCursor);
destCursor += subcorp.size();
}
cout << "Total number of examples loaded : " << examples_.size() << endl;
size_ = examples_.size();
if (size_ == 0) {
errorOnZeroExample(fileName);
}
}
// Copy the elements of 'ex' into 'rslt', randomly dropping each element
// with probability 'dropout' (a dropout of 0 copies everything).
void LayerDataHandler::insert(
vector<Base>& rslt,
const vector<Base>& ex,
float dropout) const {
if (dropout < 1e-8) {
// if dropout is not enabled, copy all elements
rslt.insert(rslt.end(), ex.begin(), ex.end());
} else {
// dropout enabled
auto rnd = [&] {
// per-thread state for rand_r; static __thread storage is
// zero-initialized, so every thread starts from seed 0
static __thread unsigned int rState;
return rand_r(&rState);
};
for (const auto& it : ex) {
auto p = (double)(rnd()) / RAND_MAX;
if (p > dropout) {
rslt.push_back(it);
}
}
}
}
void LayerDataHandler::getWordExamples(
int idx,
vector<ParseResults>& rslts) const {
assert(idx < size_);
const auto& example = examples_[idx];
assert(example.RHSFeatures.size() > 0);
// take one random sentence and train on word
auto r = rand() % example.RHSFeatures.size();
InternDataHandler::getWordExamples(example.RHSFeatures[r], rslts);
}
// Rearrange an example's lhs/rhs feature groups according to trainMode,
// applying dropout while copying (see insert()):
//   0: lhs kept; one random rhs group becomes the label
//   1: one random rhs group is the label; the others join the lhs
//   2: one random rhs group becomes the lhs; the others form the rhs
//   3: two distinct random rhs groups: one lhs, one rhs
//   4: first rhs group -> lhs, second -> rhs
void LayerDataHandler::convert(
const ParseResults& example,
ParseResults& rslt) const {
rslt.weight = example.weight;
rslt.LHSTokens.clear();
rslt.RHSTokens.clear();
if (args_->trainMode == 0) {
assert(example.LHSTokens.size() > 0);
assert(example.RHSFeatures.size() > 0);
insert(rslt.LHSTokens, example.LHSTokens, args_->dropoutLHS);
auto idx = rand() % example.RHSFeatures.size();
insert(rslt.RHSTokens, example.RHSFeatures[idx], args_->dropoutRHS);
} else {
// modes 1-4 need at least two rhs feature groups
assert(example.RHSFeatures.size() > 1);
if (args_->trainMode == 1) {
// pick one random rhs as label, the rest becomes lhs features
auto idx = rand() % example.RHSFeatures.size();
for (int i = 0; i < example.RHSFeatures.size(); i++) {
if (i == idx) {
insert(rslt.RHSTokens, example.RHSFeatures[i], args_->dropoutRHS);
} else {
insert(rslt.LHSTokens, example.RHSFeatures[i], args_->dropoutLHS);
}
}
} else
if (args_->trainMode == 2) {
// pick one random rhs as lhs, the rest becomes rhs features
auto idx = rand() % example.RHSFeatures.size();
for (int i = 0; i < example.RHSFeatures.size(); i++) {
if (i == idx) {
insert(rslt.LHSTokens, example.RHSFeatures[i], args_->dropoutLHS);
} else {
insert(rslt.RHSTokens, example.RHSFeatures[i], args_->dropoutRHS);
}
}
} else
if (args_->trainMode == 3) {
// pick one random rhs as input
auto idx = rand() % example.RHSFeatures.size();
insert(rslt.LHSTokens, example.RHSFeatures[idx], args_->dropoutLHS);
// pick another random rhs as label
int idx2;
do {
idx2 = rand() % example.RHSFeatures.size();
} while (idx == idx2);
insert(rslt.RHSTokens, example.RHSFeatures[idx2], args_->dropoutRHS);
} else
if (args_->trainMode == 4) {
// the first one as lhs and the second one as rhs
insert(rslt.LHSTokens, example.RHSFeatures[0], args_->dropoutLHS);
insert(rslt.RHSTokens, example.RHSFeatures[1], args_->dropoutRHS);
}
}
}
// Sample a random label (feature group) for negative sampling; mirrors
// InternDataHandler::getRandomRHS but works on RHSFeatures groups.
void LayerDataHandler::getRandomRHS(vector<Base>& result, bool trainWord) const {
assert(size_ > 0);
auto& ex = examples_[rand() % size_];
int r = rand() % ex.RHSFeatures.size();
result.clear();
if (args_->trainMode == 5 || trainWord) {
// pick random word
int wid = rand() % ex.RHSFeatures[r].size();
result.push_back(ex.RHSFeatures[r][wid]);
} else if (args_->trainMode == 2) {
// pick one random, the rest is rhs features
for (int i = 0; i < ex.RHSFeatures.size(); i++) {
if (i != r) {
insert(result, ex.RHSFeatures[i], args_->dropoutRHS);
}
}
} else {
insert(result, ex.RHSFeatures[r], args_->dropoutRHS);
}
}
// Dump every example to 'out' for debugging: the LHS tokens, then one
// tab-separated group per RHS feature set, in "id:weight" format.
void LayerDataHandler::save(ostream& out) {
  // const ref: no need to copy each example just to print it
  for (const auto& example : examples_) {
    out << "lhs: ";
    for (const auto& t : example.LHSTokens) {
      out << t.first << ':' << t.second << ' ';
    }
    out << "\nrhs: ";
    for (const auto& feat : example.RHSFeatures) {
      // Bug fix: the RHS cells were written to std::cout instead of 'out',
      // so saving to a file silently dropped them.
      for (const auto& r : feat) { out << r.first << ':' << r.second << ' '; }
      out << "\t";
    }
    out << endl;
  }
}
} // namespace starspace
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
/**
* This is the internal data handler class for the case where we
* have features to represent labels. It overrides a few key functions
* in DataHandler class in order to return label features for training/testing
* instead of label ids.
*/
#pragma once
#include "dict.h"
#include "data.h"
#include "doc_parser.h"
#include <string>
#include <vector>
#include <fstream>
namespace starspace {
// Internal data handler for the case where labels are represented by
// features: overrides the conversion / sampling hooks of InternDataHandler
// to operate on RHSFeatures (feature groups) instead of single label ids.
class LayerDataHandler : public InternDataHandler {
public:
explicit LayerDataHandler(std::shared_ptr<Args> args);
void convert(const ParseResults& example, ParseResults& rslts) const override;
void getWordExamples(int idx, std::vector<ParseResults>& rslts) const override;
void loadFromFile(const std::string& file,
std::shared_ptr<DataParser> parser) override;
void getRandomRHS(std::vector<Base>& results, bool trainWord = false)
const override;
void save(std::ostream& out) override;
private:
// Copy 'ex' into 'rslt', dropping each element with probability 'dropout'.
void insert(
std::vector<Base>& rslt,
const std::vector<Base>& ex,
float dropout = 0.0) const;
};
}
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
#include "doc_parser.h"
#include "utils/normalize.h"
#include <string>
#include <vector>
#include <fstream>
#include <boost/algorithm/string.hpp>
using namespace std;
namespace starspace {
// Construct a parser for label-as-feature data; all state lives in the
// DataParser base class.
LayerDataParser::LayerDataParser(
    shared_ptr<Dictionary> dict,
    shared_ptr<Args> args)
    : DataParser(dict, args) {}
// Parse one part of a line into (token id, weight) pairs. Tokens are
// separated by any character in 'sep'; when args_->useWeight is set, a
// token may carry an explicit weight as "token:weight". Unknown tokens are
// skipped. Returns true if at least one known token (or ngram) was added.
bool LayerDataParser::parse(
    string& s,
    vector<Base>& feats,
    const string& sep) {
  // split each part into tokens
  vector<string> tokens;
  // 'sep' is already a string: the original wrapped it in string(sep),
  // constructing a redundant temporary copy
  boost::split(tokens, s, boost::is_any_of(sep));
  // const ref avoids copying every token string in the loop
  for (const auto& token : tokens) {
    string t = token;
    float weight = 1.0;
    if (args_->useWeight) {
      std::size_t pos = token.find(":");
      if (pos != std::string::npos) {
        t = token.substr(0, pos);
        weight = atof(token.substr(pos + 1).c_str());
      }
    }
    if (args_->normalizeText) {
      normalize_text(t);
    }
    int32_t wid = dict_->getId(t);
    if (wid != -1) {
      feats.push_back(make_pair(wid, weight));
    }
  }
  if (args_->ngrams > 1) {
    addNgrams(tokens, feats, args_->ngrams);
  }
  return feats.size() > 0;
}
// Parse a full input line into LHS tokens (trainMode 0 only) and a list of
// RHS feature groups, honoring an optional leading "__weight__:w" field.
// Returns true when the example has enough parts for the current trainMode.
bool LayerDataParser::parse(
    string& line,
    ParseResults& rslt,
    const string& sep) {
  vector<string> parts;
  // Bug fix: the 'sep' parameter was ignored and "\t" was hard-coded, so a
  // caller-supplied separator was silently overridden. The default is still
  // "\t", so default behavior is unchanged.
  boost::split(parts, line, boost::is_any_of(sep));
  int start_idx = 0;
  if (parts[0].find("__weight__") != std::string::npos) {
    std::size_t pos = parts[0].find(":");
    if (pos != std::string::npos) {
      rslt.weight = atof(parts[0].substr(pos + 1).c_str());
    }
    start_idx = 1;
  }
  // Guard: a line consisting only of the weight field has no content parts;
  // previously parts[start_idx] below indexed out of bounds.
  if (start_idx >= (int)parts.size()) {
    return false;
  }
  if (args_->trainMode == 0) {
    // the first part is input features
    parse(parts[start_idx], rslt.LHSTokens);
    start_idx += 1;
  }
  for (size_t i = start_idx; i < parts.size(); i++) {
    vector<Base> feats;
    if (parse(parts[i], feats)) {
      rslt.RHSFeatures.push_back(feats);
    }
  }
  bool isValid;
  if (args_->trainMode == 0) {
    isValid = (rslt.LHSTokens.size() > 0) && (rslt.RHSFeatures.size() > 0);
  } else {
    // need to have at least two examples
    isValid = rslt.RHSFeatures.size() > 1;
  }
  return isValid;
}
} // namespace starspace
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
/**
* This is the parser class for the case where we have features
* to represent labels. It overrides a few key functions such as
* parse(input, output) and check(example) in the basic Parser class.
*/
#pragma once
#include "dict.h"
#include "parser.h"
#include <string>
#include <vector>
namespace starspace {
class LayerDataParser : public DataParser {
public:
  LayerDataParser(
      std::shared_ptr<Dictionary> dict,
      std::shared_ptr<Args> args);

  // Parse one sep-delimited feature string into (id, weight) pairs;
  // returns true when at least one dictionary-known feature was found.
  bool parse(
      std::string& line,
      std::vector<Base>& rslt,
      const std::string& sep=" ");

  // Parse one full tab-separated example: optional "__weight__:w" field,
  // then (in trainMode 0) the LHS features, then one RHS feature group
  // per remaining tab-separated field.
  bool parse(
      std::string& line,
      ParseResults& rslt,
      const std::string& sep="\t") override;
};
}
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
#include "starspace.h"
#include <iostream>
#include <boost/algorithm/string/predicate.hpp>
using namespace std;
using namespace starspace;
// Command-line entry point: parse arguments, then either train a model
// (optionally warm-started from --initModel) or load and evaluate one.
int main(int argc, char** argv) {
  shared_ptr<Args> args = make_shared<Args>();
  args->parseArgs(argc, argv);
  args->printArgs();
  StarSpace sp(args);

  if (!args->isTrain) {
    // Evaluation path: load a trained model (tsv or binary) and test it.
    if (boost::algorithm::ends_with(args->model, ".tsv")) {
      sp.initFromTsv(args->model);
    } else {
      sp.initFromSavedModel(args->model);
      cout << "------Loaded model args:\n";
      args->printArgs();
    }
    sp.evaluate();
    return 0;
  }

  // Training path: fresh init, or warm-start from an existing model.
  if (args->initModel.empty()) {
    sp.init();
  } else if (boost::algorithm::ends_with(args->initModel, ".tsv")) {
    sp.initFromTsv(args->initModel);
  } else {
    sp.initFromSavedModel(args->initModel);
    cout << "------Loaded model args:\n";
    args->printArgs();
  }
  sp.train();
  sp.saveModel(args->model);
  sp.saveModelTsv(args->model + ".tsv");
  return 0;
}
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
/**
* Mostly a collection of convenience routines around ublas.
* We avoid doing any actual compute-intensive work in this file.
*/
#pragma once
#include <math.h>
#include <iostream>
#include <functional>
#include <random>
#include <thread>
#include <algorithm>
#include <vector>
#include <boost/numeric/ublas/matrix.hpp>
#include <boost/numeric/ublas/matrix_proxy.hpp>
#include <boost/numeric/ublas/io.hpp>
namespace starspace {
struct MatrixDims {
size_t r, c;
size_t numElts() const { return r * c; }
bool operator==(const MatrixDims& rhs) {
return r == rhs.r && c == rhs.c;
}
};
template<typename Real = float>
struct Matrix {
static const int kAlign = 64;
boost::numeric::ublas::matrix<Real> matrix;
explicit Matrix(MatrixDims dims,
Real sd = 1.0) :
matrix(dims.r, dims.c)
{
assert(matrix.size1() == dims.r);
assert(matrix.size2() == dims.c);
if (sd > 0.0) {
randomInit(sd);
}
}
explicit Matrix(const std::vector<std::vector<Real>>& init) {
size_t rows = init.size();
size_t maxCols = 0;
for (const auto& r : init) {
maxCols = std::max(maxCols, r.size());
}
alloc(rows, maxCols);
for (size_t i = 0; i < numRows(); i++) {
size_t j;
for (j = 0; j < init[i].size(); j++) {
(*this)[i][j] = init[i][j];
}
for (; j < numCols(); j++) {
(*this)[i][j] = 0.0;
}
}
}
explicit Matrix(std::istream& in) {
in >> matrix;
}
Matrix() {
alloc(0, 0);
}
Real* operator[](size_t i) {
assert(i >= 0);
assert(i < numRows());
return &matrix(i, 0);
}
const Real* operator[](size_t i) const {
assert(i >= 0);
assert(i < numRows());
return &matrix(i, 0);
}
Real& cell(size_t i, size_t j) {
assert(i >= 0);
assert(i < numCols());
assert(j < numCols());
assert(j >= 0);
return matrix(i, j);
}
void add(const Matrix<Real>& rhs, Real scale = 1.0) {
matrix += scale * rhs.matrix;
}
void forEachCell(std::function<void(Real&)> l) {
for (size_t i = 0; i < numRows(); i++)
for (size_t j = 0; j < numCols(); j++)
l(matrix(i, j));
}
void forEachCell(std::function<void(Real)> l) const {
for (size_t i = 0; i < numRows(); i++)
for (size_t j = 0; j < numCols(); j++)
l(matrix(i, j));
}
void forEachCell(std::function<void(Real&, size_t, size_t)> l) {
for (size_t i = 0; i < numRows(); i++)
for (size_t j = 0; j < numCols(); j++)
l(matrix(i, j), i, j);
}
void forEachCell(std::function<void(Real, size_t, size_t)> l) const {
for (size_t i = 0; i < numRows(); i++)
for (size_t j = 0; j < numCols(); j++)
l(matrix(i, j), i, j);
}
void sanityCheck() const {
#ifndef NDEBUG
forEachCell([&](Real r, size_t i, size_t j) {
assert(!std::isnan(r));
assert(!std::isinf(r));
});
#endif
}
void forRow(size_t r, std::function<void(Real&, size_t)> l) {
for (size_t j = 0; j < numCols(); j++) l(matrix(r, j), j);
}
void forRow(size_t r, std::function<void(Real, size_t)> l) const {
for (size_t j = 0; j < numCols(); j++) l(matrix(r, j), j);
}
void forCol(size_t c, std::function<void(Real&, size_t)> l) {
for (size_t i = 0; i < numRows(); i++) l(matrix(i, c), i);
}
void forCol(size_t c, std::function<void(Real, size_t)> l) const {
for (size_t i = 0; i < numRows(); i++) l(matrix(c, i), i);
}
static void mul(const Matrix& l, const Matrix& r, Matrix& dest) {
dest.matrix = boost::numeric::ublas::prod(l.matrix, r.matrix);
}
void updateRow(size_t r, Matrix& addend, Real scale = 1.0) {
using namespace boost::numeric::ublas;
assert(addend.numRows() == 1);
assert(addend.numCols() == numCols());
row(r) += Row { addend.matrix, 0 } * scale;
}
typedef boost::numeric::ublas::matrix_row<boost::numeric::ublas::matrix<Real>>
Row;
Row row(size_t r) { return Row{ matrix, r }; }
/* implicit */ operator Row() {
assert(numRows() == 1);
return Row{ matrix, 0 };
}
size_t numElts() const { return numRows() * numCols(); }
size_t numRows() const { return matrix.size1(); }
size_t numCols() const { return matrix.size2(); }
MatrixDims getDims() const { return { numRows(), numCols() }; }
void reshape(MatrixDims dims) {
if (dims == getDims()) return;
alloc(dims.r, dims.c);
}
typedef size_t iterator;
iterator begin() { return 0; }
iterator end() { return numElts(); }
void write(std::ostream& out) {
out << matrix;
}
void randomInit(Real sd = 1.0) {
if (numElts() > 0) {
// Multi-threaded initialization brings debug init time down
// from minutes to seconds.
auto d = &matrix(0, 0);
std::minstd_rand gen;
auto nd = std::normal_distribution<Real>(0, sd);
for (size_t i = 0; i < numElts(); i++) {
d[i] = nd(gen);
};
}
}
private:
void alloc(size_t r, size_t c) {
matrix = boost::numeric::ublas::matrix<Real>(r, c);
}
};
}
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
#include "model.h"

#include <algorithm>
#include <atomic>
#include <cassert>
#include <chrono>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <numeric>
#include <random>
#include <sstream>
#include <thread>

#include <boost/archive/binary_iarchive.hpp>
#include <boost/archive/binary_oarchive.hpp>
#include <boost/archive/text_iarchive.hpp>
#include <boost/archive/text_oarchive.hpp>
#include <boost/algorithm/string.hpp>
#include <boost/lexical_cast.hpp>
namespace starspace {
using namespace std;
using namespace boost::numeric;
// Stash the configuration and vocabulary handles, then allocate the
// embedding tables.
EmbedModel::EmbedModel(
    shared_ptr<Args> args,
    shared_ptr<Dictionary> dict) {
  dict_ = dict;
  args_ = args;
  initModelWeights();
}
// Allocate the LHS/RHS embedding tables. The LHS table has one row per
// word and label (plus n-gram hash buckets when --ngrams > 1); with
// --shareEmb the same table serves both sides. Adagrad accumulators are
// sized to match when --adagrad is on.
void EmbedModel::initModelWeights() {
  assert(dict_ != nullptr);
  size_t num_lhs = dict_->nwords() + dict_->nlabels();
  if (args_->ngrams > 1) {
    // Reserve extra rows for hashed n-gram buckets.
    num_lhs += args_->bucket;
  }

  // Use reset() rather than assigning a freshly constructed shared_ptr,
  // consistent with EmbedModel::load().
  LHSEmbeddings_.reset(
      new SparseLinear<Real>({num_lhs, args_->dim}, args_->initRandSd));

  if (args_->shareEmb) {
    RHSEmbeddings_ = LHSEmbeddings_;
  } else {
    RHSEmbeddings_.reset(
        new SparseLinear<Real>({num_lhs, args_->dim}, args_->initRandSd));
  }

  if (args_->adagrad) {
    LHSUpdates_.resize(LHSEmbeddings_->numRows());
    RHSUpdates_.resize(RHSEmbeddings_->numRows());
  }

  if (args_->verbose) {
    cout << "Initialized model weights. Model size :\n"
         << "matrix : " << LHSEmbeddings_->numRows() << ' '
         << LHSEmbeddings_->numCols() << endl;
  }
}
// Inner product of two equal-length, non-empty row views.
Real dot(Matrix<Real>::Row a, Matrix<Real>::Row b) {
  assert(a.size() > 0);
  assert(a.size() == b.size());
  return ublas::inner_prod(a, b);
}
// L2 norm of a row, clamped away from zero so callers can divide safely.
Real norm2(Matrix<Real>::Row a) {
  const auto n = norm_2(a);
  const auto eps = std::numeric_limits<Real>::epsilon();
  return (eps < n) ? n : eps;
}
// consistent accessor methods for straight indices and index-weight pairs
int32_t index(int32_t idx) { return idx; }
int32_t index(std::pair<int32_t, Real> idxWeightPair) {
return idxWeightPair.first;
}
constexpr float weight(int32_t idx) { return 1.0; }
float weight(std::pair<int32_t, Real> idxWeightPair) {
return idxWeightPair.second;
}
// Convenience overload: allocate the result and forward to the
// in-place projection.
Matrix<Real> EmbedModel::projectRHS(const std::vector<Base>& ws) {
  Matrix<Real> result;
  projectRHS(ws, result);
  return result;
}
// Convenience overload: allocate the result and forward to the
// in-place projection.
Matrix<Real> EmbedModel::projectLHS(const std::vector<Base>& ws) {
  Matrix<Real> result;
  projectLHS(ws, result);
  return result;
}
void EmbedModel::projectLHS(const std::vector<Base>& ws, Matrix<Real>& retval) {
LHSEmbeddings_->forward(ws, retval);
if (ws.size()) {
auto norm = (args_->similarity == "dot") ?
pow(ws.size(), args_->p) : norm2(retval);
retval.matrix /= norm;
}
}
void EmbedModel::projectRHS(const std::vector<Base>& ws, Matrix<Real>& retval) {
RHSEmbeddings_->forward(ws, retval);
if (ws.size()) {
auto norm = (args_->similarity == "dot") ?
pow(ws.size(), args_->p) : norm2(retval);
retval.matrix /= norm;
}
}
// Train on a single parsed example: skip degenerate examples, optionally
// dump both sides in debug mode, scale the rate by the example weight, and
// dispatch to the configured loss (softmax -> NLL, otherwise hinge).
Real EmbedModel::trainOneExample(
    shared_ptr<InternDataHandler> data,
    const ParseResults& s,
    int negSearchLimit,
    Real rate,
    bool trainWord) {
  if (s.LHSTokens.empty() || s.RHSTokens.empty()) {
    return 0.0;
  }

  if (args_->debug) {
    auto printVec = [&](const vector<Base>& vec) {
      cout << "vec : ";
      for (auto v : vec) {cout << v.first << ':' << v.second << ' ';}
      cout << endl;
    };
    printVec(s.LHSTokens);
    printVec(s.RHSTokens);
    cout << endl;
  }

  const Real wRate = s.weight * rate;
  if (args_->loss != "softmax") {
    // default is hinge loss
    return trainOne(
        data,
        s.LHSTokens, s.RHSTokens,
        negSearchLimit, wRate,
        trainWord);
  }
  return trainNLL(
      data,
      s.LHSTokens, s.RHSTokens,
      negSearchLimit, wRate,
      trainWord);
}
// Run one training epoch over `data` using (numThreads - 1) worker threads
// plus one norm-truncation thread. The learning rate decays linearly from
// `rate` toward `finishRate` in kDecrStep-sized chunks. Returns mean loss.
Real EmbedModel::train(shared_ptr<InternDataHandler> data,
                       int numThreads,
                       std::chrono::time_point<std::chrono::high_resolution_clock> t_start,
                       int epochs_done,
                       Real rate,
                       Real finishRate,
                       bool verbose) {
  assert(rate >= finishRate);
  assert(rate >= 0.0);

  // Use a layer of indirection when accessing the corpus to allow shuffling.
  auto numSamples = data->getSize();
  vector<int> indices(numSamples);
  std::iota(indices.begin(), indices.end(), 0);
  // std::random_shuffle was removed in C++17; use std::shuffle with an
  // explicitly seeded engine.
  {
    std::mt19937 rng(std::random_device{}());
    std::shuffle(indices.begin(), indices.end(), rng);
  }

  // If we decrement after *every* sample, precision causes us to lose the
  // update. Guard the divisor: with fewer than kDecrStep samples the
  // original divided by zero here.
  const int kDecrStep = 1000;
  auto decrPerKSample =
      (rate - finishRate) / std::max<size_t>(numSamples / kDecrStep, 1);
  const Real negSearchLimit = std::min(numSamples,
                                       size_t(args_->negSearchLimit));

  numThreads = std::max(numThreads, 2);
  numThreads -= 1; // Withold one thread for the norm thread.
  numThreads = std::min(numThreads, int(numSamples));
  vector<Real> losses(numThreads);
  vector<long> counts(numThreads);

  auto trainThread = [&](int idx,
                         vector<int>::const_iterator start,
                         vector<int>::const_iterator end) {
    assert(start >= indices.begin());
    assert(end >= start);
    assert(end <= indices.end());
    bool amMaster = idx == 0;
    auto t_epoch_start = std::chrono::high_resolution_clock::now();
    losses[idx] = 0.0;
    counts[idx] = 0;
    for (auto ip = start; ip < end; ip++) {
      auto i = *ip;
      float thisLoss = 0.0;
      if (args_->trainMode == 5 || args_->trainWord) {
        vector<ParseResults> exs;
        data->getWordExamples(i, exs);
        for (const auto& ex : exs) {
          thisLoss = trainOneExample(data, ex, negSearchLimit, rate, true);
          assert(thisLoss >= 0.0);
          counts[idx]++;
          losses[idx] += thisLoss;
        }
      }
      if (args_->trainMode != 5) {
        ParseResults ex;
        data->getExampleById(i, ex);
        thisLoss = trainOneExample(data, ex, negSearchLimit, rate, false);
        assert(thisLoss >= 0.0);
        counts[idx]++;
        losses[idx] += thisLoss;
      }
      // update rate racily.
      if ((i % kDecrStep) == (kDecrStep - 1)) {
        rate -= decrPerKSample;
      }
      // Progress / ETA report, printed by the master thread only.
      if (amMaster && ((ip - indices.begin()) % 100 == 99 || (ip + 1) == end)) {
        auto t_end = std::chrono::high_resolution_clock::now();
        auto t_epoch_spent =
            std::chrono::duration<double>(t_end-t_epoch_start).count();
        double ex_done_this_epoch = ip - indices.begin();
        int ex_left = ((end - start) * (args_->epoch - epochs_done))
            - ex_done_this_epoch;
        double ex_done = epochs_done * (end - start) + ex_done_this_epoch;
        double time_per_ex = double(t_epoch_spent) / ex_done_this_epoch;
        int eta = int(time_per_ex * double(ex_left));
        auto tot_spent = std::chrono::duration<double>(t_end-t_start).count();
        if (tot_spent > args_->maxTrainTime) {
          break;
        }
        double epoch_progress = ex_done_this_epoch / (end - start);
        double progress = ex_done / (ex_done + ex_left);
        if (eta > args_->maxTrainTime - tot_spent) {
          eta = args_->maxTrainTime - tot_spent;
          progress = tot_spent / (eta + tot_spent);
        }
        int etah = eta / 3600;
        int etam = (eta - etah * 3600) / 60;
        int etas = (eta - etah * 3600 - etam * 60);
        int toth = int(tot_spent) / 3600;
        int totm = (tot_spent - toth * 3600) / 60;
        int tots = (tot_spent - toth * 3600 - totm * 60);
        std::cerr << std::fixed;
        std::cerr << "\rEpoch: " << std::setprecision(1) << 100 * epoch_progress << "%";
        std::cerr << " lr: " << std::setprecision(6) << rate;
        std::cerr << " loss: " << std::setprecision(6) << losses[idx] / counts[idx];
        if (eta < 60) {
          std::cerr << " eta: <1min ";
        } else {
          std::cerr << " eta: " << std::setprecision(3) << etah << "h" << etam << "m";
        }
        std::cerr << " tot: " << std::setprecision(3) << toth << "h" << totm << "m" << tots << "s ";
        std::cerr << " (" << std::setprecision(1) << 100 * progress << "%)";
        std::cerr << std::flush;
      }
    }
  };

  vector<thread> threads;
  // atomic: the flag is written by this thread and polled by the truncator;
  // a plain bool was a data race.
  std::atomic<bool> doneTraining(false);
  // Ceiling division. The original `ceil(numSamples / numThreads)` truncated
  // in integer division *before* ceil and could leave tail samples untrained.
  size_t numPerThread = (numSamples + numThreads - 1) / numThreads;
  assert(numPerThread > 0);
  for (int i = 0; i < numThreads; i++) {
    size_t start = i * numPerThread;
    size_t end = std::min(start + numPerThread, numSamples);
    assert(end >= start);
    assert(end <= numSamples);
    auto b = indices.begin() + start;
    auto e = indices.begin() + end;
    assert(b >= indices.begin());
    assert(e >= b);
    assert(e <= indices.end());
    threads.emplace_back(thread([=] {
      trainThread(i, b, e);
    }));
  }

  // .. and a norm truncation thread. It's not worth it to slow
  // down every update with truncation, so just work our way through
  // truncating as needed on a separate thread.
  std::thread truncator([&] {
    auto trunc = [](Matrix<Real>::Row row, double maxNorm) {
      auto norm = norm2(row);
      if (norm > maxNorm) {
        row *= (maxNorm / norm);
      }
    };
    // size_t: the original `int i` could overflow (UB) on long runs.
    for (size_t i = 0; !doneTraining.load(); i++) {
      auto wIdx = i % LHSEmbeddings_->numRows();
      trunc(LHSEmbeddings_->row(wIdx), args_->norm);
    }
  });
  for (auto& t: threads) t.join();
  // All done. Shut the truncator down.
  doneTraining = true;
  truncator.join();

  Real totLoss = std::accumulate(losses.begin(), losses.end(), 0.0);
  long totCount = std::accumulate(counts.begin(), counts.end(), 0);
  return totLoss / totCount;
}
// Rescale `row` in place so its L2 norm becomes maxNorm; rows already at
// the target norm are left untouched (not all of them are updated).
void EmbedModel::normalize(Matrix<float>::Row row, double maxNorm) {
  auto curNorm = norm2(row);
  if (curNorm == maxNorm) {
    return;
  }
  if (curNorm == 0.0) { // Unlikely!
    curNorm = 0.01;
  }
  row *= (maxNorm / curNorm);
}
// One hinge-loss update with negative sampling: draw up to negSearchLimit
// random label sets (capped at --maxNegSamples kept), accumulate the margin
// losses of the violating triples, and backprop through the mean negative.
// Returns the (averaged) loss for this example.
float EmbedModel::trainOne(shared_ptr<InternDataHandler> data,
                           const vector<Base>& items,
                           const vector<Base>& labels,
                           size_t negSearchLimit,
                           Real rate0,
                           bool trainWord) {
  if (items.size() == 0) return 0.0; // nothing to learn.
  using namespace boost::numeric::ublas;

  // Keep all the activations on the stack so we can asynchronously
  // update.
  Matrix<Real> lhs, rhsP, rhsN;

  projectLHS(items, lhs);
  check(lhs);
  auto cols = lhs.numCols();

  projectRHS(labels, rhsP);
  check(rhsP);

  const auto posSim = similarity(lhs, rhsP);

  // Hinge loss of a (lhs, positive, negative) triple, clamped to
  // [0, kMaxLoss] so downstream arithmetic has wiggle room.
  auto tripleLoss = [&] (Real posSim, Real negSim) {
    auto val = args_->margin - posSim + negSim;
    assert(!isnan(posSim));
    assert(!isnan(negSim));
    assert(!isinf(posSim));
    assert(!isinf(negSim));
    const auto kMaxLoss = 10e7;
    auto retval = std::max(std::min(val, kMaxLoss), 0.0);
    return retval;
  };

  // Select negative examples.
  Real loss = 0.0;
  std::vector<Matrix<Real>> negs;
  std::vector<std::vector<Base>> negLabelsBatch;
  Matrix<Real> negMean;
  negMean.matrix = zero_matrix<Real>(1, cols);

  // size_t index: the original compared a signed int against size_t.
  for (size_t i = 0; i < negSearchLimit &&
       negs.size() < args_->maxNegSamples; i++) {
    std::vector<Base> negLabels;
    do {
      data->getRandomRHS(negLabels, trainWord);
    } while (negLabels == labels);
    projectRHS(negLabels, rhsN);
    check(rhsN);
    auto thisLoss = tripleLoss(posSim, similarity(lhs, rhsN));
    if (thisLoss > 0.0) {
      loss += thisLoss;
      negs.emplace_back(rhsN);
      negLabelsBatch.emplace_back(negLabels);
      negMean.add(rhsN);
      assert(loss >= 0.0);
    }
  }
  // Couldn't find a negative example given reasonable effort, so
  // give up. Checked BEFORE the divisions below — the original divided
  // negMean by negs.size() == 0 first.
  if (negs.size() == 0) return 0.0;
  loss /= negSearchLimit;
  negMean.matrix /= negs.size();

  assert(!std::isinf(loss));
  if (rate0 == 0.0) return loss;

  // Let w be the average of the input features, t+ be the positive
  // example and t- be the average of the negative examples.
  // Our error E is:
  //
  //    E = k - dot(w, t+) + dot(w, t-)
  //
  // Differentiating term-by-term we get:
  //
  //     dE / dw  = t- - t+
  //     dE / dt- = w
  //     dE / dt+ = -w
  //
  // This is the innermost loop, so cache misses count. Please do some perf
  // testing if you end up modifying it.

  // gradW = \sum_i t_i- - t+. We're done with negMean, so reuse it.
  auto gradW = negMean;
  gradW.add(rhsP, -1);
  auto nRate = rate0 / negs.size();
  std::vector<Real> negRate(negs.size());
  std::fill(negRate.begin(), negRate.end(), nRate);
  backward(items, labels, negLabelsBatch,
           gradW, lhs,
           rate0, -rate0, negRate);
  return loss;
}
// One softmax / negative-log-likelihood update: the true label set is
// class 0 and args_->negSearchLimit randomly drawn label sets form the
// remaining classes. Returns -log P(class 0).
// NOTE(review): the negSearchLimit parameter is unused here; the class
// count comes from args_->negSearchLimit directly — confirm the caller's
// min(numSamples, negSearchLimit) cap was meant to apply.
float EmbedModel::trainNLL(shared_ptr<InternDataHandler> data,
                           const vector<Base>& items,
                           const vector<Base>& labels,
                           int32_t negSearchLimit,
                           Real rate0,
                           bool trainWord) {
  if (items.size() == 0) return 0.0; // nothing to learn.
  Matrix<Real> lhs, rhsP, rhsN;
  using namespace boost::numeric::ublas;

  projectLHS(items, lhs);
  check(lhs);

  projectRHS(labels, rhsP);
  check(rhsP);

  // label is treated as class 0
  auto numClass = args_->negSearchLimit + 1;
  std::vector<Real> prob(numClass);
  std::vector<Matrix<Real>> negClassVec;
  std::vector<std::vector<Base>> negLabelsBatch;

  prob[0] = dot(lhs, rhsP);
  Real max = prob[0];

  // Draw the negative classes (resampling anything equal to the true
  // labels), score each, and track the running max for stable exp below.
  for (int i = 1; i < numClass; i++) {
    std::vector<Base> negLabels;
    do {
      data->getRandomRHS(negLabels, trainWord);
    } while (negLabels == labels);
    projectRHS(negLabels, rhsN);
    check(rhsN);
    negClassVec.push_back(rhsN);
    negLabelsBatch.push_back(negLabels);
    prob[i] = dot(lhs, rhsN);
    max = std::max(prob[i], max);
  }

  // Numerically stable softmax: shift by the max before exponentiating.
  Real base = 0;
  for (int i = 0; i < numClass; i++) {
    prob[i] = exp(prob[i] - max);
    base += prob[i];
  }

  // normalize the probabilities
  for (int i = 0; i < numClass; i++) { prob[i] /= base; };

  Real loss = - log(prob[0]);

  // Let w be the average of the words in the post, t+ be the
  // positive example (the tag the post has) and t- be the average
  // of the negative examples (the tags we searched for with submarginal
  // separation above).
  // Our error E is:
  //
  //    E = - log P(t+)
  //
  // Where P(t) = exp(dot(w, t)) / (\sum_{t'} exp(dot(w, t')))
  //
  // Differentiating term-by-term we get:
  //
  //    dE / dw  = t+ (P(t+) - 1)
  //    dE / dt+ = w (P(t+) - 1)
  //    dE / dt- = w P(t-)
  auto gradW = rhsP;
  gradW.matrix *= (prob[0] - 1);
  for (int i = 0; i < numClass - 1; i++) {
    gradW.add(negClassVec[i], prob[i + 1]);
  }

  // Per-negative update rates are scaled by each class's probability.
  std::vector<Real> negRate(numClass - 1);
  for (int i = 0; i < negRate.size(); i++) {
    negRate[i] = prob[i + 1] * rate0;
  }

  backward(items, labels, negLabelsBatch,
           gradW, lhs,
           rate0, (prob[0] - 1 ) * rate0, negRate);

  return loss;
}
// Apply one SGD (or adagrad) step. `gradW` is dE/dw for the LHS rows and
// `lhs` serves as dE/dt for the RHS rows (see the derivations in trainOne
// and trainNLL). Each touched embedding row is updated in place as
// row -= rate * grad, with the per-feature weight folded into the rate.
void EmbedModel::backward(
    const vector<Base>& items,
    const vector<Base>& labels,
    const vector<vector<Base>>& negLabels,
    Matrix<Real>& gradW,
    Matrix<Real>& lhs,
    Real rate_lhs,
    Real rate_rhsP,
    const vector<Real>& rate_rhsN) {
  using namespace boost::numeric::ublas;
  auto cols = lhs.numCols();

  typedef
      std::function<void(MatrixRow&, const MatrixRow&, Real, Real, std::vector<Real>&, int32_t)>
      UpdateFn;
  // Plain SGD row update: dest -= rate * src.
  auto updatePlain = [&] (MatrixRow& dest, const MatrixRow& src,
                          Real rate,
                          Real weight,
                          std::vector<Real>& adagradWeight,
                          int32_t idx) {
    dest -= (rate * src);
  };
  // Adagrad: accumulate the per-row gradient energy (`weight` is a squared
  // gradient norm, averaged over columns) and divide the rate by its
  // square root before the plain update; 1e-6 avoids division by zero.
  auto updateAdagrad = [&] (MatrixRow& dest, const MatrixRow& src,
                            Real rate,
                            Real weight,
                            std::vector<Real>& adagradWeight,
                            int32_t idx) {
    assert(idx < adagradWeight.size());
    adagradWeight[idx] += weight / cols;
    rate /= sqrt(adagradWeight[idx] + 1e-6);
    updatePlain(dest, src, rate, weight, adagradWeight, idx);
  };
  auto update = args_->adagrad ?
      UpdateFn(updateAdagrad) : UpdateFn(updatePlain);

  // Squared norms of the two gradients, needed only as adagrad input.
  Real n1 = 0, n2 = 0;
  if (args_->adagrad) {
    n1 = dot(gradW, gradW);
    n2 = dot(lhs, lhs);
  }
  // Update input items.
  for (auto w : items) {
    auto row = LHSEmbeddings_->row(index(w));
    update(row, gradW, rate_lhs * weight(w), n1, LHSUpdates_, index(w));
  }
  // Update positive example.
  for (auto la : labels) {
    auto row = RHSEmbeddings_->row(index(la));
    update(row, lhs, rate_rhsP * weight(la), n2, RHSUpdates_, index(la));
  }
  // Update negative examples, each batch at its own rate.
  for (size_t i = 0; i < negLabels.size(); i++) {
    for (auto la : negLabels[i]) {
      auto row = RHSEmbeddings_->row(index(la));
      update(row, lhs, rate_rhsN[i] * weight(la), n2, RHSUpdates_, index(la));
    }
  }
}
// Similarity between two rows: dot product or cosine per --similarity.
Real EmbedModel::similarity(const MatrixRow& a, const MatrixRow& b) {
  Real sim;
  if (args_->similarity == "dot") {
    sim = dot(a, b);
  } else {
    sim = cosine(a, b);
  }
  assert(!isnan(sim));
  assert(!isinf(sim));
  return sim;
}
// cos(a, b) = <a, b> / (|a| |b|); degenerate (zero) vectors score 0.
Real EmbedModel::cosine(const MatrixRow& a, const MatrixRow& b) {
  const auto normA = dot(a, a);
  const auto normB = dot(b, b);
  if (normA == 0.0 || normB == 0.0) {
    return 0.0;
  }
  return dot(a, b) / sqrt(normA * normB);
}
// Brute-force k-nearest-neighbour scan of `point` against every row of the
// given lookup table. Returns up to numSim (id, similarity) pairs sorted by
// descending similarity.
vector<pair<int32_t, Real>>
EmbedModel::kNN(shared_ptr<SparseLinear<Real>> lookup,
                Matrix<Real> point,
                int numSim) {
  typedef pair<int32_t, Real> Cand;
  const int total = dict_->nwords() + dict_->nlabels();

  // Current best candidates, seeded with sentinels any real hit displaces.
  vector<Cand> best(std::min(numSim, total), Cand{ -1, -1.0 });
  auto sortBest = [&best] {
    std::sort(best.begin(), best.end(),
              [](const Cand& a, const Cand& b) { return a.second > b.second; });
  };

  Matrix<Real> candVec;
  for (int i = 0; i < total; i++) {
    lookup->forward(i, candVec);
    Real sim;
    if (args_->similarity == "dot") {
      sim = dot(point, candVec);
    } else {
      sim = cosine(point, candVec);
    }
    // Replace the current worst candidate and re-sort.
    if (sim > best.back().second) {
      best.back() = Cand{ i, sim };
      sortBest();
    }
  }
  // A surviving sentinel indicates an internal error.
  for (const auto& c : best) {
    if (c.first == -1 || c.second == -1.0) {
      abort();
    }
  }
  return best;
}
// Parse one "<symbol><sep><v0><sep>...<vN>" line into the LHS table row for
// `symbol` (`cols` values expected). Malformed records are repaired —
// truncated or zero-padded — with a warning, rather than rejected.
// Unknown symbols are skipped. lineNum is for diagnostics only (-1 = unknown).
void EmbedModel::loadTsvLine(string& line, int lineNum,
                             int cols, const string sep) {
  vector<string> pieces;
  static const string zero = "0.0";
  // Strip trailing spaces
  while (line.size() && isspace(line[line.size() - 1])) {
    line.resize(line.size() - 1);
  }
  boost::split(pieces, line, boost::is_any_of(sep));
  if (pieces.size() > cols + 1) {
    cout << "Hmm, truncating long (" << pieces.size() <<
        ") record at line " << lineNum;
    // (The original wrapped this loop in a redundant `if (true)`.)
    for (size_t i = cols; i < pieces.size(); i++) {
      cout << "Warning excess fields " << pieces[i]
           << "; misformatted file?";
    }
    pieces.resize(cols + 1);
  }
  if (pieces.size() == cols) {
    cout << "Missing record at line " << lineNum <<
        "; assuming empty string";
    pieces.insert(pieces.begin(), "");
  }
  while (pieces.size() < cols + 1) {
    cout << "Zero-padding short record at line " << lineNum;
    pieces.push_back(zero);
  }
  auto idx = dict_->getId(pieces[0]);
  if (idx == -1) {
    if (pieces[0].size() > 0) {
      cerr << "Failed to insert record: " << line << "\n";
    }
    return;
  }
  auto row = LHSEmbeddings_->row(idx);
  for (int i = 0; i < cols; i++) {
    row(i) = boost::lexical_cast<Real>(pieces[i + 1].c_str());
  }
}
// Load a tsv model in parallel: the file is split into one byte range per
// hardware thread, each range advanced to the next line boundary, and every
// thread parses its own range through a private stream.
void EmbedModel::loadTsv(const char* fname, const string sep) {
  cout << "Loading model from file " << fname << endl;
  auto cols = args_->dim;

  std::ifstream ifs(fname);
  // File length in bytes, leaving the stream's read position untouched.
  auto filelen = [&](ifstream& f) {
    auto pos = f.tellg();
    f.seekg(0, ios_base::end);
    auto retval = f.tellg();
    f.seekg(pos, ios_base::beg);
    return retval;
  };

  auto len = filelen(ifs);
  // POSIX-only: number of online processors.
  auto numThreads = sysconf(_SC_NPROCESSORS_ONLN);
  vector<off_t> partitions(numThreads + 1);
  partitions[0] = 0;
  partitions[numThreads] = len;

  // Seek to each evenly spaced byte offset, then discard the partial line
  // so every partition boundary falls on a line start.
  string unused;
  for (int i = 1; i < numThreads; i++) {
    ifs.seekg((len / numThreads) * i);
    getline(ifs, unused);
    partitions[i] = ifs.tellg();
  }

  // It's possible that the ranges in partitions overlap; consider,
  // e.g., a machine with 100 hardware threads and only 99 lines
  // in the file. In this case, we'll do some excess work but loadTsvLine
  // is idempotent, so it is ok.
  std::vector<thread> threads;
  for (int i = 0; i < numThreads; i++) {
    auto body = [this, fname, cols, sep, i, &partitions]() {
      // Get our own seek pointer.
      ifstream ifs(fname);
      ifs.seekg(partitions[i]);
      string line;
      while (ifs.tellg() < partitions[i + 1] && getline(ifs, line)) {
        // We don't know the line number. Super-bummer.
        loadTsvLine(line, -1, cols, sep);
      }
    };
    threads.emplace_back(body);
  }
  for (auto& t: threads) {
    t.join();
  }
  cout << "Model loaded.\n";
}
// Sequentially load a tsv model from a stream, one embedding row per line.
void EmbedModel::loadTsv(istream& in, const string sep) {
  auto cols = LHSEmbeddings_->numCols();
  assert(RHSEmbeddings_->numCols() == cols);

  string line;
  for (int lineNum = 1; getline(in, line); lineNum++) {
    loadTsvLine(line, lineNum, cols, sep);
  }
}
// Write one "<symbol><sep><v0><sep>...<vN>" line per dictionary entry.
// Only the LHS table is dumped (with --shareEmb it is also the RHS table).
void EmbedModel::saveTsv(ostream& out, const char sep) const {
  const auto dumpTable = [&](shared_ptr<SparseLinear<Real>> emb) {
    const auto size = dict_->nwords() + dict_->nlabels();
    for (size_t i = 0; i < size; i++) {
      // NOTE(review): original comment said "Skip invalid IDs", but no
      // row is actually skipped here — confirm intent.
      string symbol = dict_->getSymbol(i);
      out << symbol;
      emb->forRow(i, [&](Real r, size_t j) { out << sep << r; });
      out << "\n";
    }
  };
  dumpTable(LHSEmbeddings_);
}
// Serialize the LHS table; the RHS table follows only when the two tables
// are distinct (--shareEmb off).
void EmbedModel::save(ostream& out) const {
  LHSEmbeddings_->write(out);
  if (args_->shareEmb) {
    return; // shared table, already written
  }
  RHSEmbeddings_->write(out);
}
// Deserialize the tables written by save(): LHS first, then either alias
// it (--shareEmb) or read a separate RHS table from the stream.
void EmbedModel::load(ifstream& in) {
  LHSEmbeddings_.reset(new SparseLinear<Real>(in));
  if (!args_->shareEmb) {
    RHSEmbeddings_.reset(new SparseLinear<Real>(in));
  } else {
    RHSEmbeddings_ = LHSEmbeddings_;
  }
}
}
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
#pragma once
#include "matrix.h"
#include "proj.h"
#include "dict.h"
#include "utils/normalize.h"
#include "utils/args.h"
#include "data.h"
#include "doc_data.h"
#include <fstream>
#include <boost/noncopyable.hpp>
#include <vector>
namespace starspace {
// Model-wide scalar type plus ublas row/vector shorthands.
typedef float Real;
// Standard C++11 `decltype` instead of the GNU `typeof` extension, which
// is unavailable under -std=c++11 (non-GNU mode).
typedef boost::numeric::ublas::matrix_row<decltype(Matrix<Real>::matrix)>
    MatrixRow;
typedef boost::numeric::ublas::vector<Real> Vector;
/*
* The model is basically two lookup tables: one for left hand side
* (LHS) entities, one for right hand side (RHS) entities.
*/
struct EmbedModel : public boost::noncopyable {
public:
  // Builds the model over the given dictionary and allocates the
  // embedding tables (see initModelWeights()).
  explicit EmbedModel(std::shared_ptr<Args> args,
                      std::shared_ptr<Dictionary> dict);

  typedef std::vector<ParseResults> Corpus;

  // Runs one training epoch over `data` using worker threads plus a norm
  // truncation thread; the rate decays from startRate toward endRate.
  // Returns the mean loss per trained example.
  float train(std::shared_ptr<InternDataHandler> data,
              int numThreads,
              std::chrono::time_point<std::chrono::high_resolution_clock> t_start,
              int epochs_done,
              Real startRate,
              Real endRate,
              bool verbose = true);

  // Evaluation is a zero-rate training pass: losses are computed but no
  // updates are applied (trainOne/trainNLL return early when rate == 0).
  float test(std::shared_ptr<InternDataHandler> data, int numThreads) {
    return this->train(data, numThreads,
                       std::chrono::high_resolution_clock::now(), 0,
                       0.0, 0.0, false);
  }

  // Trains on a single parsed example; dispatches to trainNLL for softmax
  // loss and trainOne (hinge) otherwise. Returns the example loss.
  float trainOneExample(
      std::shared_ptr<InternDataHandler> data,
      const ParseResults& s,
      int negSearchLimit,
      Real rate,
      bool trainWord = false);

  // Hinge-loss update for one (items, labels) pair with negative sampling.
  float trainOne(std::shared_ptr<InternDataHandler> data,
                 const std::vector<Base>& items,
                 const std::vector<Base>& labels,
                 size_t maxNegSamples,
                 Real rate,
                 bool trainWord = false);

  // Softmax (negative log-likelihood) update for one pair.
  float trainNLL(std::shared_ptr<InternDataHandler> data,
                 const std::vector<Base>& items,
                 const std::vector<Base>& labels,
                 int32_t negSearchLimit,
                 Real rate,
                 bool trainWord = false);

  // Applies the gradient step to the touched LHS rows, positive RHS rows,
  // and each batch of negative RHS rows at their respective rates.
  void backward(const std::vector<Base>& items,
                const std::vector<Base>& labels,
                const std::vector<std::vector<Base>>& negLabels,
                Matrix<Real>& gradW,
                Matrix<Real>& lhs,
                Real rate_lhs,
                Real rate_rhsP,
                const std::vector<Real>& rate_rhsN);

  // Querying: brute-force nearest-neighbour scan over a lookup table.
  std::vector<std::pair<int32_t, Real>>
  kNN(std::shared_ptr<SparseLinear<Real>> lookup,
      Matrix<Real> point,
      int numSim);

  std::vector<std::pair<int32_t, Real>>
  findLHSLike(Matrix<Real> point, int numSim = 5) {
    return kNN(LHSEmbeddings_, point, numSim);
  }

  std::vector<std::pair<int32_t, Real>>
  findRHSLike(Matrix<Real> point, int numSim = 5) {
    return kNN(RHSEmbeddings_, point, numSim);
  }

  // Project a bag of features into the embedding space; the in-place
  // overloads write into `retval`.
  Matrix<Real> projectRHS(const std::vector<Base>& ws);
  Matrix<Real> projectLHS(const std::vector<Base>& ws);

  void projectLHS(const std::vector<Base>& ws, Matrix<Real>& retval);
  void projectRHS(const std::vector<Base>& ws, Matrix<Real>& retval);

  // Tsv and binary (de)serialization of the embedding tables.
  void loadTsv(std::istream& in, const std::string sep = "\t ");
  void loadTsv(const char* fname, const std::string sep = "\t ");
  void loadTsv(const std::string& fname, const std::string sep = "\t ") {
    return loadTsv(fname.c_str(), sep);
  }

  void saveTsv(std::ostream& out, const char sep = '\t') const;

  void save(std::ostream& out) const;
  void load(std::ifstream& in);

  const std::string& lookupLHS(int32_t idx) const {
    return dict_->getSymbol(idx);
  }
  const std::string& lookupRHS(int32_t idx) const {
    return dict_->getLabel(idx);
  }

  void loadTsvLine(std::string& line, int lineNum, int cols,
                   const std::string sep = "\t");

  std::shared_ptr<Dictionary> getDict() { return dict_; }

  std::shared_ptr<SparseLinear<Real>>& getLHSEmbeddings() {
    return LHSEmbeddings_;
  }
  const std::shared_ptr<SparseLinear<Real>>& getLHSEmbeddings() const {
    return LHSEmbeddings_;
  }

  std::shared_ptr<SparseLinear<Real>>& getRHSEmbeddings() {
    return RHSEmbeddings_;
  }
  const std::shared_ptr<SparseLinear<Real>>& getRHSEmbeddings() const {
    return RHSEmbeddings_;
  }

  void initModelWeights();

  // Similarity between two 1 x dim rows: dot or cosine per --similarity.
  Real similarity(const MatrixRow& a, const MatrixRow& b);
  Real similarity(Matrix<Real>& a, Matrix<Real>& b) {
    return similarity(asRow(a), asRow(b));
  }

  static Real cosine(const MatrixRow& a, const MatrixRow& b);
  static Real cosine(Matrix<Real>& a, Matrix<Real>& b) {
    return cosine(asRow(a), asRow(b));
  }

  // View a single-row matrix as a row.
  static MatrixRow asRow(Matrix<Real>& m) {
    assert(m.numRows() == 1);
    return MatrixRow(m.matrix, 0);
  }

  // Rescale a row to the given L2 norm in place.
  static void normalize(Matrix<Real>::Row row, double maxNorm = 1.0);
  static void normalize(Matrix<Real>& m) { normalize(asRow(m)); }

private:
  std::shared_ptr<Dictionary> dict_;
  std::shared_ptr<SparseLinear<Real>> LHSEmbeddings_;
  std::shared_ptr<SparseLinear<Real>> RHSEmbeddings_;
  std::shared_ptr<Args> args_;

  // Per-row adagrad accumulators (allocated only with --adagrad).
  std::vector<Real> LHSUpdates_;
  std::vector<Real> RHSUpdates_;

  // NOTE(review): `debug` is false in BOTH branches; the #ifdef suggests
  // it was meant to be true in non-NDEBUG builds — confirm intent.
#ifdef NDEBUG
  static const bool debug = false;
#else
  static const bool debug = false;
#endif

  static void check(const Matrix<Real>& m) {
    m.sanityCheck();
  }

  // NaN/inf scan of a raw ublas matrix; a no-op unless `debug` is true.
  static void check(const boost::numeric::ublas::matrix<Real>& m) {
    if (!debug) return;
    for (int i = 0; i < m.size1(); i++) {
      for (int j = 0; j < m.size2(); j++) {
        assert(!std::isnan(m(i, j)));
        assert(!std::isinf(m(i, j)));
      }
    }
  }
};
}
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
#include "parser.h"
#include "utils/normalize.h"
#include <string>
#include <vector>
#include <fstream>
#include <iostream>
#include <boost/algorithm/string.hpp>
using namespace std;
namespace starspace {
// Strips a single trailing `toChomp` character (newline by default) from
// `line`, in place. Leaves the string untouched when it is empty or does not
// end with that character.
void chomp(std::string& line, char toChomp = '\n') {
  if (!line.empty() && line.back() == toChomp) {
    line.pop_back();
  }
}
// Constructs a parser bound to a dictionary and the run configuration.
DataParser::DataParser(
    shared_ptr<Dictionary> dict,
    shared_ptr<Args> args)
  : dict_(dict),
    args_(args) {
}
bool DataParser::parse(
std::string& s,
ParseResults& rslts,
const string& sep) {
chomp(s);
vector<string> toks;
boost::split(toks, s, boost::is_any_of(string(sep)));
return parse(toks, rslts);
}
// Tokenizes one line for dictionary building: splits on the separator set,
// strips an optional ":weight" suffix when -useWeight is on, normalizes text
// when requested, and appends every token except the special "__weight__"
// marker to `tokens`.
void DataParser::parseForDict(
    string& line,
    vector<string>& tokens,
    const string& sep) {
  chomp(line);
  vector<string> toks;
  boost::split(toks, line, boost::is_any_of(string(sep)));
  // Range-for with a const reference replaces the old signed `int` index
  // (signed/unsigned comparison) and avoids an extra copy per token.
  for (const auto& tok : toks) {
    string token = tok;
    if (args_->useWeight) {
      // Keep only the part before the first ':' — the rest is the weight.
      std::size_t pos = tok.find(":");
      if (pos != std::string::npos) {
        token = tok.substr(0, pos);
      }
    }
    if (args_->normalizeText) {
      normalize_text(token);
    }
    if (token.find("__weight__") == std::string::npos) {
      tokens.push_back(token);
    }
  }
}
// check whether it is a valid example
// Validates a parsed example according to the train mode:
// - mode 0 requires both LHS and RHS tokens;
// - mode 5 (bag-of-words) requires only LHS tokens;
// - all other modes need at least 2 RHS tokens (labels), since one is picked
//   as the target and the rest form the input.
bool DataParser::check(const ParseResults& example) {
  if (args_->trainMode == 0) {
    // require lhs and rhs
    return !example.RHSTokens.empty() && !example.LHSTokens.empty();
  } else if (args_->trainMode == 5) {  // was `} if (...)`: missing `else`
    // only requires lhs.
    return !example.LHSTokens.empty();
  } else {
    // lhs is not required, but rhs should contain at least 2 examples
    return example.RHSTokens.size() > 1;
  }
}
// Appends hashed n-gram (2..n) features to `line`. Word tokens are first
// hashed individually; each contiguous run of up to n hashes is then combined
// with the rolling Dictionary::HASH_C scheme and bucketed into the ngram
// region of the id space (offset past all words and labels).
void DataParser::addNgrams(
    const std::vector<std::string>& tokens,
    std::vector<Base>& line,
    int n) {
  vector<int32_t> hashes;
  for (const auto& token : tokens) {  // const ref: no per-token string copy
    entry_type type = dict_->getType(token);
    if (type == entry_type::word) {
      hashes.push_back(dict_->hash(token));
    }
  }
  // size_t indices replace int32_t to avoid signed/unsigned comparisons with
  // hashes.size(); callers pass n = args_->ngrams > 1, so i + n is safe.
  for (size_t i = 0; i < hashes.size(); i++) {
    uint64_t h = hashes[i];
    for (size_t j = i + 1; j < hashes.size() && j < i + n; j++) {
      h = h * Dictionary::HASH_C + hashes[j];
      int64_t id = h % args_->bucket;
      // All n-grams get a fixed weight of 1.0.
      line.push_back(make_pair(dict_->nwords() + dict_->nlabels() + id, 1.0));
    }
  }
}
// Converts raw string tokens into a weighted example:
// - a "__weight__:w" token sets the example weight and is otherwise skipped;
// - with -useWeight, "token:w" splits into the token and its float weight;
// - tokens are optionally normalized, looked up in the dictionary (unknown
//   tokens are dropped), then routed: words -> LHSTokens, labels -> RHSTokens;
// - finally ngram features are appended to the LHS when -ngrams > 1.
// Returns check(rslts), i.e. whether the example is valid for the train mode.
bool DataParser::parse(
const std::vector<std::string>& tokens,
ParseResults& rslts) {
for (auto &token: tokens) {
if (token.find("__weight__") != std::string::npos) {
std::size_t pos = token.find(":");
if (pos != std::string::npos) {
// Example-level weight, e.g. "__weight__:0.5".
rslts.weight = atof(token.substr(pos + 1).c_str());
}
continue;
}
string t = token;
float weight = 1.0;
if (args_->useWeight) {
std::size_t pos = token.find(":");
if (pos != std::string::npos) {
// Per-token weight, e.g. "movie:2.0".
t = token.substr(0, pos);
weight = atof(token.substr(pos + 1).c_str());
}
}
if (args_->normalizeText) {
normalize_text(t);
}
int32_t wid = dict_->getId(t);
if (wid < 0) {
// Token not in the dictionary (e.g. below minCount): silently dropped.
continue;
}
entry_type type = dict_->getType(wid);
if (type == entry_type::word) {
rslts.LHSTokens.push_back(make_pair(wid, weight));
}
if (type == entry_type::label) {
rslts.RHSTokens.push_back(make_pair(wid, weight));
}
}
if (args_->ngrams > 1) {
// Ngram features are always added on the LHS (input) side.
addNgrams(tokens, rslts.LHSTokens, args_->ngrams);
}
return check(rslts);
}
// Flat variant of parse: converts tokens into (id, weight) pairs without
// separating LHS/RHS — only word tokens are kept, labels are ignored here.
// Used e.g. when embedding an arbitrary document. Returns true when at least
// one feature was produced.
bool DataParser::parse(
const std::vector<std::string>& tokens,
vector<Base>& rslts) {
for (auto &token: tokens) {
auto t = token;
float weight = 1.0;
if (args_->useWeight) {
// Optional "token:weight" syntax, as in the ParseResults overload.
std::size_t pos = token.find(":");
if (pos != std::string::npos) {
t = token.substr(0, pos);
weight = atof(token.substr(pos + 1).c_str());
}
}
if (args_->normalizeText) {
normalize_text(t);
}
int32_t wid = dict_->getId(t);
if (wid < 0) {
// Unknown token: dropped.
continue;
}
entry_type type = dict_->getType(wid);
if (type == entry_type::word) {
rslts.push_back(make_pair(wid, weight));
}
}
if (args_->ngrams > 1) {
addNgrams(tokens, rslts, args_->ngrams);
}
return rslts.size() > 0;
}
} // namespace starspace
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
/**
* This is the basic class of data parsing.
* It provides essential functions as follows:
* - parse(input, output):
* takes input as a line of string (or a vector of string tokens)
 *      and returns an output result, which is one example containing l.h.s.
 *      features
* and r.h.s. features.
*
* - parseForDict(input, tokens):
* takes input as a line of string, output tokens to be added for building
* the dictionary.
*
* - check(example):
* checks whether the example is a valid example.
*
* - addNgrams(input, output):
* add ngrams from input as output.
*
* One can write different parsers for data with different format.
*/
#pragma once
#include "dict.h"
#include <string>
#include <vector>
namespace starspace {
typedef std::pair<int32_t, float> Base;
// One parsed training/test example.
struct ParseResults {
// Example-level weight (settable via a "__weight__:w" token); defaults to 1.
float weight = 1.0;
// Input-side (word/ngram) features as (id, weight) pairs.
std::vector<Base> LHSTokens;
// Label-side features as (id, weight) pairs.
std::vector<Base> RHSTokens;
// Per-label feature bags; used by the labelDoc format (LayerDataParser).
std::vector<std::vector<Base>> RHSFeatures;
};
typedef std::vector<ParseResults> Corpus;
// Base parser that converts raw text lines into dictionary-id examples.
// Subclasses (e.g. LayerDataParser) override the virtual entry points to
// support other file formats.
class DataParser {
public:
explicit DataParser(
std::shared_ptr<Dictionary> dict,
std::shared_ptr<Args> args);
// Parses one raw line into an example; returns whether it is valid.
virtual bool parse(
std::string& s,
ParseResults& rslt,
const std::string& sep="\t ");
// Tokenizes one line for dictionary construction.
virtual void parseForDict(
std::string& s,
std::vector<std::string>& tokens,
const std::string& sep="\t ");
// Flat parse: word tokens only, no LHS/RHS split.
bool parse(
const std::vector<std::string>& tokens,
std::vector<Base>& rslt);
// Token-vector parse into a full example.
bool parse(
const std::vector<std::string>& tokens,
ParseResults& rslt);
// Validates an example against the configured trainMode.
bool check(const ParseResults& example);
// Appends hashed ngram features to `line`.
// NOTE(review): the out-of-line definition spells the last parameter as
// `int n`; identical on platforms where int32_t is int — confirm.
void addNgrams(
const std::vector<std::string>& tokens,
std::vector<Base>& line,
int32_t n);
std::shared_ptr<Dictionary> getDict() { return dict_; };
// Swaps in a (re)built dictionary, e.g. after readFromFile.
void resetDict(std::shared_ptr<Dictionary> dict) { dict_ = dict; };
protected:
std::shared_ptr<Dictionary> dict_;
std::shared_ptr<Args> args_;
};
}
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
#include "proj.h"
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
// The SparseLinear class implements the lookup tables used in starspace model.
#pragma once
#include "matrix.h"
#include <stdlib.h>
#include <stdio.h>
#include <vector>
#include <assert.h>
#include <string.h>
#include <fstream>
namespace starspace {
// Embedding lookup table: a dense Matrix whose rows are looked up and summed
// by sparse (id) or weighted-sparse ((id, weight)) inputs.
template<typename Real = float>
struct SparseLinear : public Matrix<Real> {
// Random-initializes a dims-shaped table with standard deviation sd.
explicit SparseLinear(MatrixDims dims,
Real sd = 1.0) : Matrix<Real>(dims, sd) { }
// Deserializes a table from a binary stream.
explicit SparseLinear(std::ifstream& in) : Matrix<Real>(in) { }
// Copies row `in` of the table into the 1-row output matrix.
void forward(int in, Matrix<Real>& mout) {
using namespace boost::numeric::ublas;
const auto c = this->numCols();
mout.matrix.resize(1, c);
// Rows are contiguous, so a raw memcpy of c reals is valid here.
memcpy(&mout[0][0], &(*this)[in][0], c * sizeof(Real));
}
// Sums the rows indexed by `in` into the 1-row output matrix.
void forward(const std::vector<int>& in, Matrix<Real>& mout) {
using namespace boost::numeric::ublas;
const auto c = this->numCols();
mout.matrix = zero_matrix<Real>(1, c);
auto outRow = mout.row(0);
for (const auto& elt: in) {
assert(elt < this->numRows());
outRow += this->row(elt);
}
}
// Weighted variant: accumulates weight * row for each (id, weight) pair.
void forward(const std::vector<std::pair<int, Real>>& in,
Matrix<Real> &mout) {
using namespace boost::numeric::ublas;
const auto c = this->numCols();
mout.matrix = zero_matrix<Real>(1, c);
auto outRow = mout.row(0);
for (const auto& pair: in) {
assert(pair.first < this->numRows());
outRow += this->row(pair.first) * pair.second;
}
}
// SGD step: subtracts alpha * gradient (the single row of mb) from every
// row listed in `in`.
void backward(const std::vector<int>& in,
const Matrix<Real>& mb, const Real alpha) {
// Just update this racily and in-place.
// (Hogwild-style: concurrent training threads intentionally race here.)
assert(mb.numRows() == 1);
auto b = mb[0];
for (const auto& elt: in) {
auto row = (*this)[elt];
for (int i = 0; i < this->numCols(); i++) {
row[i] -= alpha * b[i];
}
}
}
// Allocates an aligned, caller-owned buffer of numCols() reals.
// NOTE(review): on failure this does `throw this;` — callers would have to
// catch a SparseLinear*; std::bad_alloc would be more conventional.
Real* allocOutput() {
Real* retval;
auto val = posix_memalign((void**)&retval, Matrix<Real>::kAlign,
this->numCols() * sizeof(Real));
if (val != 0) {
perror("could not allocate output");
throw this;
}
return retval;
}
};
}
#include <Rcpp.h>
using namespace Rcpp;
// Minimal Rcpp smoke-test export: returns a two-element list holding a
// character vector ("foo", "bar") and a numeric vector (0.0, 1.0).
// [[Rcpp::export]]
List rcpp_hello() {
  return List::create(
      CharacterVector::create("foo", "bar"),
      NumericVector::create(0.0, 1.0));
}
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
#include "starspace.h"
#include <iostream>
#include <queue>
#include <unordered_set>
#include <boost/algorithm/string.hpp>
using namespace std;
namespace starspace {
// Stores the run configuration; every other component (dictionary, parser,
// data handlers, model) is created lazily by init()/initFromSavedModel()/
// initFromTsv().
StarSpace::StarSpace(shared_ptr<Args> args)
: args_(args)
, dict_(nullptr)
, parser_(nullptr)
, trainData_(nullptr)
, validData_(nullptr)
, testData_(nullptr)
, model_(nullptr)
{}
// Instantiates the parser matching args_->fileFormat; exits on an
// unsupported format.
void StarSpace::initParser() {
  const auto& format = args_->fileFormat;
  if (format == "fastText") {
    parser_ = make_shared<DataParser>(dict_, args_);
    return;
  }
  if (format == "labelDoc") {
    parser_ = make_shared<LayerDataParser>(dict_, args_);
    return;
  }
  cerr << "Unsupported file format. Currently support: fastText or labelDoc.\n";
  exit(EXIT_FAILURE);
}
// Loads train (+ optional validation) data in training mode, or test data in
// evaluation mode. File names come from args_.
void StarSpace::initDataHandler() {
  if (args_->isTrain) {
    trainData_ = initData();
    trainData_->loadFromFile(args_->trainFile, parser_);
    // set validation data
    if (!args_->validationFile.empty()) {
      validData_ = initData();
      validData_->loadFromFile(args_->validationFile, parser_);
    }
  } else {
    // Use .empty() for consistency with the validationFile check above
    // (was the equivalent `args_->testFile != ""`).
    if (!args_->testFile.empty()) {
      testData_ = initData();
      testData_->loadFromFile(args_->testFile, parser_);
    }
  }
}
// Creates the data handler matching args_->fileFormat; exits on an
// unsupported format. The trailing return only silences compiler warnings.
shared_ptr<InternDataHandler> StarSpace::initData() {
  const auto& format = args_->fileFormat;
  if (format == "fastText") {
    return make_shared<InternDataHandler>(args_);
  }
  if (format == "labelDoc") {
    return make_shared<LayerDataHandler>(args_);
  }
  cerr << "Unsupported file format. Currently support: fastText or labelDoc.\n";
  exit(EXIT_FAILURE);
  return nullptr;
}
// Fresh-training setup: builds the dictionary from the train file, loads the
// train (and optional validation) data, and constructs the embedding model.
void StarSpace::init() {
cout << "Start to initialize starspace model.\n";
assert(args_ != nullptr);
// build dict
initParser();
dict_ = make_shared<Dictionary>(args_);
auto filename = args_->trainFile;
dict_->readFromFile(filename, parser_);
// The parser was created before the dictionary existed; point it at the
// freshly built one.
parser_->resetDict(dict_);
if (args_->debug) {dict_->save(cout);}
// init train data class
trainData_ = initData();
trainData_->loadFromFile(args_->trainFile, parser_);
// init model with args and dict
model_ = make_shared<EmbedModel>(args_, dict_);
// set validation data
if (!args_->validationFile.empty()) {
validData_ = initData();
validData_->loadFromFile(args_->validationFile, parser_);
}
}
// Restores a binary model: verifies the NUL-terminated magic signature, then
// loads args, dictionary and embedding weights, and finally sets up the
// parser and data handlers from the restored args.
void StarSpace::initFromSavedModel(const string& filename) {
cout << "Start to load a trained starspace model.\n";
std::ifstream in(filename, std::ifstream::binary);
if (!in.is_open()) {
std::cerr << "Model file cannot be opened for loading!" << std::endl;
exit(EXIT_FAILURE);
}
string magic;
char c;
// Read the signature up to its NUL terminator (written by saveModel).
// NOTE(review): get() returns int; assigning to char loses the EOF marker,
// so a truncated file without a NUL could loop reading EOF — consider
// checking in.good() as well.
while ((c = in.get()) != 0) {
magic.push_back(c);
}
cout << magic << endl;
if (magic != kMagic) {
std::cerr << "Magic signature does not match!" << std::endl;
exit(EXIT_FAILURE);
}
// load args
args_->load(in);
// init and load dict
dict_ = make_shared<Dictionary>(args_);
dict_->load(in);
// init and load model
model_ = make_shared<EmbedModel>(args_, dict_);
model_->load(in);
cout << "Model loaded.\n";
// init data parser
initParser();
initDataHandler();
}
// Restores a model from its TSV export: infers the embedding dimension from
// the first row, rebuilds the dictionary from the symbol column, and loads
// the embedding values.
void StarSpace::initFromTsv(const string& filename) {
cout << "Start to load a trained embedding model in tsv format.\n";
assert(args_ != nullptr);
ifstream in(filename);
if (!in.is_open()) {
std::cerr << "Model file cannot be opened for loading!" << std::endl;
exit(EXIT_FAILURE);
}
// Test dimension of first line, adjust args appropriately
// (This is also so we can load a TSV file without even specifying the dim.)
string line;
getline(in, line);
vector<string> pieces;
boost::split(pieces, line, boost::is_any_of("\t "));
// First column is the symbol; the rest are embedding components.
int dim = pieces.size() - 1;
if (args_->dim != dim) {
args_->dim = dim;
cout << "Setting dim from Tsv file to: " << dim << endl;
}
in.close();
// build dict
dict_ = make_shared<Dictionary>(args_);
dict_->loadDictFromModel(filename);
if (args_->debug) {dict_->save(cout);}
// load Model
model_ = make_shared<EmbedModel>(args_, dict_);
model_->loadTsv(filename, "\t ");
// init data parser
initParser();
initDataHandler();
}
// Runs the training loop: the learning rate decays linearly from args_->lr
// to ~1e-9 over the configured number of epochs; optionally saves a model
// after every epoch, reports validation error, and stops early when the
// wall-clock budget (maxTrainTime) is exceeded.
void StarSpace::train() {
float rate = args_->lr;
// Per-epoch linear decay so the final epoch ends near rate ~= 1e-9.
float decrPerEpoch = (rate - 1e-9) / args_->epoch;
auto t_start = std::chrono::high_resolution_clock::now();
for (int i = 0; i < args_->epoch; i++) {
if (args_->saveEveryEpoch && i > 0) {
auto filename = args_->model;
if (args_->saveTempModel) {
// Unique name per epoch so intermediate models are not overwritten.
filename = filename + "_epoch" + std::to_string(i);
}
saveModel(filename);
saveModelTsv(filename + ".tsv");
}
cout << "Training epoch " << i << ": " << rate << ' ' << decrPerEpoch << endl;
// The model interpolates the rate from `rate` down to `rate - decrPerEpoch`
// within the epoch.
auto err = model_->train(trainData_, args_->thread,
t_start, i,
rate, rate - decrPerEpoch);
// Trailing bytes are the UTF-8 snowman (U+2603), printed via printf.
printf("\n ---+++ %20s %4d Train error : %3.8f +++--- %c%c%c\n",
"Epoch", i, err,
0xe2, 0x98, 0x83);
if (validData_ != nullptr) {
auto valid_err = model_->test(validData_, args_->thread);
cout << "Validation error: " << valid_err << endl;
}
rate -= decrPerEpoch;
auto t_end = std::chrono::high_resolution_clock::now();
auto tot_spent = std::chrono::duration<double>(t_end-t_start).count();
if (tot_spent >args_->maxTrainTime) {
cout << "MaxTrainTime exceeded." << endl;
break;
}
}
}
void StarSpace::parseDoc(
const string& line,
vector<Base>& ids,
const string& sep) {
vector<string> tokens;
boost::split(tokens, line, boost::is_any_of(string(sep)));
parser_->parse(tokens, ids);
}
// Embeds a raw document line: parses it to feature ids and projects them
// through the LHS side of the model.
Matrix<Real> StarSpace::getDocVector(const string& line, const string& sep) {
  vector<Base> featureIds;
  parseDoc(line, featureIds, sep);
  return model_->projectLHS(featureIds);
}
// Returns the embedding row for a space-separated phrase. A single in-dict
// token is looked up directly; otherwise the word tokens are combined with
// the same rolling hash used at training time (DataParser::addNgrams) and
// the bucketed ngram row is returned.
MatrixRow StarSpace::getNgramVector(const string& phrase) {
vector<string> tokens;
boost::split(tokens, phrase, boost::is_any_of(string(" ")));
if (tokens.size() > args_->ngrams) {
std::cerr << "Error! Input ngrams size is greater than model ngrams size.\n";
exit(EXIT_FAILURE);
}
if (tokens.size() == 1) {
// looking up the entity embedding directly
auto id = dict_->getId(tokens[0]);
if (id != -1) {
return model_->getLHSEmbeddings()->row(id);
}
// Fall through to the hash path when the token is out of vocabulary.
}
uint64_t h = 0;
for (auto token: tokens) {
// Only word-typed tokens contribute, mirroring addNgrams.
if (dict_->getType(token) == entry_type::word) {
h = h * Dictionary::HASH_C + dict_->hash(token);
}
}
// Ngram rows live past the word and label rows in the embedding table.
int64_t id = h % args_->bucket;
return model_->getLHSEmbeddings()->row(id + dict_->nwords() + dict_->nlabels());
}
// Prints the k entities whose LHS embeddings are most similar to the
// embedding of `line`, one "symbol score" pair per output line.
void StarSpace::nearestNeighbor(const string& line, int k) {
  auto queryVec = getDocVector(line, " ");
  for (const auto& neighbor : model_->findLHSLike(queryVec, k)) {
    cout << dict_->getSymbol(neighbor.first) << ' ' << neighbor.second << endl;
  }
}
// Prepares the candidate set used at prediction/evaluation time: either every
// known label (default) or the documents listed in the -basedoc file. Each
// candidate is stored both as raw ids (for printing) and as its projected
// RHS vector (for similarity scoring).
void StarSpace::loadBaseDocs() {
if (args_->basedoc.empty()) {
if (args_->fileFormat == "labelDoc") {
// labelDoc has no fixed label vocabulary, so candidates must be given.
std::cerr << "Must provide base labels when label is featured.\n";
exit(EXIT_FAILURE);
}
for (int i = 0; i < dict_->nlabels(); i++) {
// Label ids start after all word ids in the embedding table.
baseDocs_.push_back({ make_pair(i + dict_->nwords(), 1.0) });
baseDocVectors_.push_back(
model_->projectRHS({ make_pair(i + dict_->nwords(), 1.0) })
);
}
cout << "Predictions use " << dict_->nlabels() << " known labels." << endl;
} else {
cout << "Loading base docs from file : " << args_->basedoc << endl;
ifstream fin(args_->basedoc);
if (!fin.is_open()) {
std::cerr << "Base doc file cannot be opened for loading!" << std::endl;
exit(EXIT_FAILURE);
}
string line;
while (getline(fin, line)) {
vector<Base> ids;
parseDoc(line, ids, "\t ");
baseDocs_.push_back(ids);
auto docVec = model_->projectRHS(ids);
baseDocVectors_.push_back(docVec);
}
fin.close();
if (baseDocVectors_.size() == 0) {
std::cerr << "ERROR: basedoc file '" << args_->basedoc << "' is empty." << std::endl;
exit(EXIT_FAILURE);
}
cout << "Finished loading " << baseDocVectors_.size() << " base docs.\n";
}
}
// Scores one LHS example against every base doc and appends the top-K
// (score, basedoc-index) pairs to `pred`, best first.
void StarSpace::predictOne(
const vector<Base>& input,
vector<Predictions>& pred) {
auto lhsM = model_->projectLHS(input);
// Max-heap on (score, index): Predictions compares by score first.
std::priority_queue<Predictions> heap;
for (int i = 0; i < baseDocVectors_.size(); i++) {
auto cur_score = model_->similarity(lhsM, baseDocVectors_[i]);
heap.push({ cur_score, i });
}
// get the first K predictions
int i = 0;
while (i < args_->K && heap.size() > 0) {
pred.push_back(heap.top());
heap.pop();
i++;
}
}
// Evaluates one (lhs, rhs) test pair: computes the rank of the true label's
// score among all base-doc scores (ties broken by a coin flip), fills `pred`
// with the top-K candidates (index 0 = the true label, i+1 = basedoc i), and
// returns the single-example metrics.
Metrics StarSpace::evaluateOne(
const vector<Base>& lhs,
const vector<Base>& rhs,
vector<Predictions>& pred) {
std::priority_queue<Predictions> heap;
auto lhsM = model_->projectLHS(lhs);
auto rhsM = model_->projectRHS(rhs);
// Our evaluation function currently assumes there is only one correct label.
// TODO: generalize this to the multilabel case.
auto score = model_->similarity(lhsM, rhsM);
int rank = 1;
heap.push({ score, 0 });
for (int i = 0; i < baseDocVectors_.size(); i++) {
// in the case basedoc labels are not provided, all labels become basedoc,
// and we skip the correct label for comparison.
if ((args_->basedoc.empty()) && (i == rhs[0].first - dict_->nwords())) {
continue;
}
auto cur_score = model_->similarity(lhsM, baseDocVectors_[i]);
if (cur_score > score) {
rank++;
} else if (cur_score == score) {
// Randomized tie-breaking: equal scores beat the true label half the
// time, so ties do not systematically inflate or deflate the rank.
float flip = (float) rand() / RAND_MAX;
if (flip > 0.5) {
rank++;
}
}
heap.push({ cur_score, i + 1 });
}
// get the first K predictions
int i = 0;
while (i < args_->K && heap.size() > 0) {
pred.push_back(heap.top());
heap.pop();
i++;
}
Metrics s;
s.clear();
s.update(rank);
return s;
}
// Writes the symbols of a token list to `ofs`, space-separated and
// newline-terminated. Hashed ngram ids (>= dict size) have no symbol and
// are skipped.
void StarSpace::printDoc(ostream& ofs, const vector<Base>& tokens) {
  for (const auto& tok : tokens) {
    // skip ngram tokens
    if (tok.first < dict_->size()) {
      ofs << dict_->getSymbol(tok.first) << ' ';
    }
  }
  ofs << endl;
}
// Evaluates the whole test set: shards the examples across args_->thread
// worker threads, each accumulating its own Metrics, then averages and
// prints the combined result; optionally dumps per-example predictions to
// args_->predictionFile.
void StarSpace::evaluate() {
// check that it is not in trainMode 5
if (args_->trainMode == 5) {
std::cerr << "Test is undefined in trainMode 5. Please use other trainMode for testing.\n";
exit(EXIT_FAILURE);
}
// set dropout probability to 0 in test case
args_->dropoutLHS = 0.0;
args_->dropoutRHS = 0.0;
loadBaseDocs();
int N = testData_->getSize();
auto numThreads = args_->thread;
vector<thread> threads;
vector<Metrics> metrics(numThreads);
// Pre-sized so each worker writes its own slots without synchronization.
vector<vector<Predictions>> predictions(N);
int numPerThread = ceil((float) N / numThreads);
assert(numPerThread > 0);
vector<ParseResults> examples;
testData_->getNextKExamples(N, examples);
auto evalThread = [&] (int idx, int start, int end) {
metrics[idx].clear();
for (int i = start; i < end; i++) {
auto s = evaluateOne(examples[i].LHSTokens, examples[i].RHSTokens, predictions[i]);
metrics[idx].add(s);
}
};
for (int i = 0; i < numThreads; i++) {
// Clamp each shard to [0, N); trailing shards may be empty.
auto start = std::min(i * numPerThread, N);
auto end = std::min(start + numPerThread, N);
assert(end >= start);
// Capture by value: i/start/end change on the next loop iteration.
threads.emplace_back(thread([=] {
evalThread(i, start, end);
}));
}
for (auto& t : threads) t.join();
Metrics result;
result.clear();
for (int i = 0; i < numThreads; i++) {
if (args_->debug) { metrics[i].print(); }
result.add(metrics[i]);
}
result.average();
result.print();
if (!args_->predictionFile.empty()) {
// print out prediction results to file
ofstream ofs(args_->predictionFile);
for (int i = 0; i < N; i++) {
ofs << "Example " << i << ":\nLHS:\n";
printDoc(ofs, examples[i].LHSTokens);
ofs << "RHS: \n";
printDoc(ofs, examples[i].RHSTokens);
ofs << "Predictions: \n";
for (auto pred : predictions[i]) {
// Index 0 marks the true label (see evaluateOne); others map to
// baseDocs_[index - 1].
if (pred.second == 0) {
ofs << "(++) [" << pred.first << "]\t";
printDoc(ofs, examples[i].RHSTokens);
} else {
ofs << "(--) [" << pred.first << "]\t";
printDoc(ofs, baseDocs_[pred.second - 1]);
}
}
ofs << "\n";
}
ofs.close();
}
}
// Serializes the model to a binary file: NUL-terminated magic signature
// first (read back by initFromSavedModel), then args, dictionary and
// embedding weights in that order.
void StarSpace::saveModel(const string& filename) {
cout << "Saving model to file : " << filename << endl;
std::ofstream ofs(filename, std::ofstream::binary);
if (!ofs.is_open()) {
std::cerr << "Model file cannot be opened for saving!" << std::endl;
exit(EXIT_FAILURE);
}
// sign model
ofs.write(kMagic.data(), kMagic.size() * sizeof(char));
// NUL terminator: the loader reads the signature up to this byte.
ofs.put(0);
args_->save(ofs);
dict_->save(ofs);
model_->save(ofs);
ofs.close();
}
void StarSpace::saveModelTsv(const string& filename) {
cout << "Saving model in tsv format : " << filename << endl;
ofstream fout(filename);
model_->saveTsv(fout, '\t');
fout.close();
}
} // starspace
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
#pragma once
#include "utils/args.h"
#include "dict.h"
#include "matrix.h"
#include "parser.h"
#include "doc_parser.h"
#include "model.h"
#include "utils/utils.h"
namespace starspace {
typedef std::pair<Real, int32_t> Predictions;
// Top-level facade tying together dictionary, parser, data handlers and the
// embedding model: training, evaluation, prediction and (de)serialization.
class StarSpace {
public:
explicit StarSpace(std::shared_ptr<Args> args);
// Fresh-training setup: build dict, load data, construct the model.
void init();
// Restore a model from its TSV export.
void initFromTsv(const std::string& filename);
// Restore a model from the binary format written by saveModel.
void initFromSavedModel(const std::string& filename);
void train();
void evaluate();
// Embedding row for a (possibly multi-word) phrase via the training hash.
MatrixRow getNgramVector(const std::string& phrase);
// Embedding of a whole document line.
Matrix<Real> getDocVector(
const std::string& line,
const std::string& sep = " \t");
// Tokenize + id-encode a raw line into `ids`.
void parseDoc(
const std::string& line,
std::vector<Base>& ids,
const std::string& sep);
// Print the k nearest entities to the embedding of `line`.
void nearestNeighbor(const std::string& line, int k);
void saveModel(const std::string& filename);
void saveModelTsv(const std::string& filename);
// Write the symbols of a token list to `ofs` (ngram ids skipped).
void printDoc(std::ostream& ofs, const std::vector<Base>& tokens);
// Binary-format signature; bump when the serialized layout changes.
const std::string kMagic = "STARSPACE-2017-2";
// Load the candidate set used by predictOne/evaluateOne.
void loadBaseDocs();
void predictOne(
const std::vector<Base>& input,
std::vector<Predictions>& pred);
std::shared_ptr<Args> args_;
// Candidate documents as raw ids (parallel to baseDocVectors_).
std::vector<std::vector<Base>> baseDocs_;
private:
void initParser();
void initDataHandler();
std::shared_ptr<InternDataHandler> initData();
Metrics evaluateOne(
const std::vector<Base>& lhs,
const std::vector<Base>& rhs,
std::vector<Predictions>& pred);
std::shared_ptr<Dictionary> dict_;
std::shared_ptr<DataParser> parser_;
std::shared_ptr<InternDataHandler> trainData_;
std::shared_ptr<InternDataHandler> validData_;
std::shared_ptr<InternDataHandler> testData_;
std::shared_ptr<EmbedModel> model_;
// Projected RHS vectors of the candidates (parallel to baseDocs_).
std::vector<Matrix<Real>> baseDocVectors_;
};
}
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
#include <gtest/gtest.h>
#include "../matrix.h"
using namespace starspace;
// Verifies initializer-list construction: dimensions and per-row/per-column
// traversal return the expected cells.
TEST(Matrix, init) {
srand(12);
Matrix<float> mtx {
{ { 0.01, 2.23, 3.34 },
{ 1.11, -0.4, 0.2 } } };
EXPECT_EQ(mtx.numCols(), 3);
EXPECT_EQ(mtx.numRows(), 2);
// NOTE(review): `tot` is never used — leftover from an earlier assertion?
float tot = 0.0;
mtx.forRow(1, [&](float& f, int c) {
ASSERT_TRUE(c == 0 || c == 1 || c == 2);
if (c == 0) EXPECT_FLOAT_EQ(f, 1.11);
if (c == 1) EXPECT_FLOAT_EQ(f, -0.4);
if (c == 2) EXPECT_FLOAT_EQ(f, 0.2);
});
mtx.forCol(2, [&](float& f, int r) {
ASSERT_TRUE(r == 0 || r == 1);
if (r == 0) EXPECT_FLOAT_EQ(f, 3.34);
if (r == 1) EXPECT_FLOAT_EQ(f, 0.2);
});
}
// Multiplies random matrices by a 4x4 identity.
// NOTE(review): the only assertion is commented out, so this test currently
// verifies nothing beyond "mul does not crash" — re-enable or remove.
TEST(Matrix, mulI) {
Matrix<float> I4 {
{ { 1.0, 0.0, 0.0, 0.0, },
{ 0.0, 1.0, 0.0, 0.0, },
{ 0.0, 0.0, 1.0, 0.0, },
{ 0.0, 0.0, 0.0, 1.0 } } };
for (int i = 0; i < 22; i++) {
size_t otherDim = 1 + rand() % 17;
Matrix<float> l({otherDim, 4});
Matrix<float> result({otherDim, 4});
Matrix<float>::mul(l, I4, result);
result.forEachCell([&](float& f, int i, int j) {
// EXPECT_FLOAT_EQ(result[i][j], l[i][j]);
});
}
}
// Checks A (7x3) times B (3x4) against a hand-computed expected product.
TEST(Matrix, mulRand) {
Matrix<double> A {
{ { -0.2, 0.3, 0.4 },
{ 0.2, 0.2, -0.001 },
{ 0.3, 0.5, 1 },
{ 1, 2, 3 },
{ -2, -1, 0 },
{ 0.3, 0.5, 1 },
{ 7, -0.01, -7 } } };
Matrix<double> B {
{ { 1, 2, 3, 4 },
{ -2, -1, 0, 1 },
{ 0.01, 10, 0.3, 2} } };
Matrix<double> C;
Matrix<double> expectedC {
{ { -0.796, 3.3, -0.48, 0.3 },
{ -0.20001, 0.19, 0.5997, 0.998 },
{ -0.69, 10.1, 1.2, 3.7 },
{ -2.97, 30.0, 3.9, 12.0 },
{ 0.0, -3.0, -6.0, -9.0 },
{ -0.69, 10.1, 1.2, 3.7 },
{ 6.95, -55.99, 18.9, 13.99 } } };
Matrix<double>::mul(A, B, C);
// Cell-by-cell comparison with float tolerance.
C.forEachCell([&](double d, int i, int j) {
EXPECT_FLOAT_EQ(expectedC[i][j], d);
});
}
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
#include "../proj.h"
#include <gtest/gtest.h>
using namespace std;
using namespace starspace;
// forward on a list of row indices must produce the sum of those rows.
TEST(Proj, forward) {
SparseLinear<float> sl({5, 1});
vector<int> inputs = { 1 ,
4 };
Matrix<float> output;
sl.forward(inputs, output);
EXPECT_FLOAT_EQ(output[0][0], sl[1][0] + sl[4][0]);
}
// Weighted forward must produce the weight-scaled sum of the indexed rows.
TEST(Proj, weightedForward) {
SparseLinear<float> sl({5, 1});
vector<pair<int,float>> inputs = { {1, 0.5} ,
{4, 1.5} };
Matrix<float> output;
sl.forward(inputs, output);
EXPECT_FLOAT_EQ(output[0][0], sl[1][0] * 0.5 + sl[4][0] * 1.5);
}
// forward with no indices must yield an all-zero output row.
TEST(Proj, empty) {
SparseLinear<float> sl({5, 1});
vector<int> inputs = { };
Matrix<float> output;
sl.forward(inputs, output);
output.forEachCell([&](float& f, int i, int j) {
EXPECT_EQ(f, 0.0);
});
}
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
#include "args.h"
#include <iostream>
#include <algorithm>
#include <string>
#include <cstring>
#include <assert.h>
using namespace std;
namespace starspace {
// Default configuration; every field can be overridden via parseArgs.
Args::Args() {
lr = 0.01;
// Learning-rate floor used as the decay target in training.
termLr = 1e-9;
norm = 1.0;
margin = 0.05;
wordWeight = 0.5;
// Std-dev for random embedding initialization.
initRandSd = 0.001;
dropoutLHS = 0.0;
dropoutRHS = 0.0;
p = 0.5;
dim = 10;
epoch = 5;
ws = 5;
// Wall-clock training budget in seconds (default: 100 days, i.e. unlimited
// in practice).
maxTrainTime = 60*60*24*100;
thread = 10;
maxNegSamples = 10;
negSearchLimit = 50;
minCount = 1;
minCountLabel = 1;
// Number of predictions returned per example.
K = 5;
verbose = false;
debug = false;
adagrad = true;
normalizeText = false;
trainMode = 0;
fileFormat = "fastText";
label = "__label__";
// Hash-bucket count for ngram features.
bucket = 2000000;
ngrams = 1;
loss = "hinge";
similarity = "cosine";
isTrain = false;
shareEmb = true;
saveEveryEpoch = false;
saveTempModel = false;
useWeight = false;
trainWord = false;
}
// Case-insensitive boolean flag parsing: "true"/"TRUE"/"1" etc. -> true,
// anything else -> false.
bool Args::isTrue(string arg) {
  // Take the char as unsigned char: calling tolower on a negative char
  // (possible for non-ASCII bytes when char is signed) is undefined
  // behavior. The lambda needs no captures, so the capture list is empty.
  std::transform(arg.begin(), arg.end(), arg.begin(),
      [](unsigned char c) { return static_cast<char>(tolower(c)); }
  );
  return (arg == "true" || arg == "1");
}
// Parses the command line: argv[1] selects train/test mode, then options are
// consumed in "-name value" pairs. Unknown options, missing mandatory files
// and out-of-range values terminate the process with usage output.
// NOTE(review): every option assumes a following value; a trailing "-flag"
// with no value reads argv[argc] — confirm callers never pass that.
void Args::parseArgs(int argc, char** argv) {
if (argc <= 1) {
cerr << "Usage: need to specify whether it is train or test.\n";
printHelp();
exit(EXIT_FAILURE);
}
// First positional argument: run mode.
if (strcmp(argv[1], "train") == 0) {
isTrain = true;
} else if (strcmp(argv[1], "test") == 0) {
isTrain = false;
} else if (strcmp(argv[1], "-h") == 0 || strcmp(argv[1], "-help") == 0) {
std::cerr << "Here is the help! Usage:" << std::endl;
printHelp();
exit(EXIT_FAILURE);
} else {
cerr << "Usage: the first argument should be either train or test.\n";
printHelp();
exit(EXIT_FAILURE);
}
int i = 2;
while (i < argc) {
if (argv[i][0] != '-') {
cout << "Provided argument without a dash! Usage:" << endl;
printHelp();
exit(EXIT_FAILURE);
}
// handling "--"
// GNU-style "--name" is accepted by skipping the first dash.
if (strlen(argv[i]) >= 2 && argv[i][1] == '-') {
argv[i] = argv[i] + 1;
}
// String-valued options.
if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "-help") == 0) {
std::cerr << "Here is the help! Usage:" << std::endl;
printHelp();
exit(EXIT_FAILURE);
} else if (strcmp(argv[i], "-trainFile") == 0) {
trainFile = string(argv[i + 1]);
} else if (strcmp(argv[i], "-validationFile") == 0) {
validationFile = string(argv[i + 1]);
} else if (strcmp(argv[i], "-testFile") == 0) {
testFile = string(argv[i + 1]);
} else if (strcmp(argv[i], "-predictionFile") == 0) {
predictionFile = string(argv[i + 1]);
} else if (strcmp(argv[i], "-basedoc") == 0) {
basedoc = string(argv[i + 1]);
} else if (strcmp(argv[i], "-model") == 0) {
model = string(argv[i + 1]);
} else if (strcmp(argv[i], "-initModel") == 0) {
initModel = string(argv[i + 1]);
} else if (strcmp(argv[i], "-fileFormat") == 0) {
fileFormat = string(argv[i + 1]);
} else if (strcmp(argv[i], "-label") == 0) {
label = string(argv[i + 1]);
} else if (strcmp(argv[i], "-loss") == 0) {
loss = string(argv[i + 1]);
} else if (strcmp(argv[i], "-similarity") == 0) {
similarity = string(argv[i + 1]);
// Float-valued options (parsed with atof: malformed input becomes 0).
} else if (strcmp(argv[i], "-lr") == 0) {
lr = atof(argv[i + 1]);
} else if (strcmp(argv[i], "-p") == 0) {
p = atof(argv[i + 1]);
} else if (strcmp(argv[i], "-termLr") == 0) {
termLr = atof(argv[i + 1]);
} else if (strcmp(argv[i], "-norm") == 0) {
norm = atof(argv[i + 1]);
} else if (strcmp(argv[i], "-margin") == 0) {
margin = atof(argv[i + 1]);
} else if (strcmp(argv[i], "-initRandSd") == 0) {
initRandSd = atof(argv[i + 1]);
} else if (strcmp(argv[i], "-dropoutLHS") == 0) {
dropoutLHS = atof(argv[i + 1]);
} else if (strcmp(argv[i], "-dropoutRHS") == 0) {
dropoutRHS = atof(argv[i + 1]);
} else if (strcmp(argv[i], "-wordWeight") == 0) {
wordWeight = atof(argv[i + 1]);
// Integer-valued options (parsed with atoi: malformed input becomes 0).
} else if (strcmp(argv[i], "-dim") == 0) {
dim = atoi(argv[i + 1]);
} else if (strcmp(argv[i], "-epoch") == 0) {
epoch = atoi(argv[i + 1]);
} else if (strcmp(argv[i], "-ws") == 0) {
ws = atoi(argv[i + 1]);
} else if (strcmp(argv[i], "-maxTrainTime") == 0) {
maxTrainTime = atoi(argv[i + 1]);
} else if (strcmp(argv[i], "-thread") == 0) {
thread = atoi(argv[i + 1]);
} else if (strcmp(argv[i], "-maxNegSamples") == 0) {
maxNegSamples = atoi(argv[i + 1]);
} else if (strcmp(argv[i], "-negSearchLimit") == 0) {
negSearchLimit = atoi(argv[i + 1]);
} else if (strcmp(argv[i], "-minCount") == 0) {
minCount = atoi(argv[i + 1]);
} else if (strcmp(argv[i], "-minCountLabel") == 0) {
minCountLabel = atoi(argv[i + 1]);
} else if (strcmp(argv[i], "-bucket") == 0) {
bucket = atoi(argv[i + 1]);
} else if (strcmp(argv[i], "-ngrams") == 0) {
ngrams = atoi(argv[i + 1]);
} else if (strcmp(argv[i], "-K") == 0) {
K = atoi(argv[i + 1]);
} else if (strcmp(argv[i], "-trainMode") == 0) {
trainMode = atoi(argv[i + 1]);
// Boolean-valued options (via isTrue: "true"/"1" case-insensitive).
} else if (strcmp(argv[i], "-verbose") == 0) {
verbose = isTrue(string(argv[i + 1]));
} else if (strcmp(argv[i], "-debug") == 0) {
debug = isTrue(string(argv[i + 1]));
} else if (strcmp(argv[i], "-adagrad") == 0) {
adagrad = isTrue(string(argv[i + 1]));
} else if (strcmp(argv[i], "-shareEmb") == 0) {
shareEmb = isTrue(string(argv[i + 1]));
} else if (strcmp(argv[i], "-normalizeText") == 0) {
normalizeText = isTrue(string(argv[i + 1]));
} else if (strcmp(argv[i], "-saveEveryEpoch") == 0) {
saveEveryEpoch = isTrue(string(argv[i + 1]));
} else if (strcmp(argv[i], "-saveTempModel") == 0) {
saveTempModel = isTrue(string(argv[i + 1]));
} else if (strcmp(argv[i], "-useWeight") == 0) {
useWeight = isTrue(string(argv[i + 1]));
} else if (strcmp(argv[i], "-trainWord") == 0) {
trainWord = isTrue(string(argv[i + 1]));
} else {
cerr << "Unknown argument: " << argv[i] << std::endl;
printHelp();
exit(EXIT_FAILURE);
}
// Every option consumes a value, so advance past both.
i += 2;
}
// Mode-specific mandatory arguments.
if (isTrain) {
if (trainFile.empty() || model.empty()) {
cerr << "Empty train file or output model path." << endl;
printHelp();
exit(EXIT_FAILURE);
}
} else {
if (testFile.empty() || model.empty()) {
cerr << "Empty test file or model path." << endl;
printHelp();
exit(EXIT_FAILURE);
}
}
// check for trainMode
// (NOTE(review): "Uknown" typo lives in the user-facing string below; left
// untouched here since this edit changes comments only.)
if ((trainMode < 0) || (trainMode > 5)) {
cerr << "Uknown trainMode. We currently support the follow train modes:\n";
cerr << "trainMode 0: at training time, one label from RHS is picked as true label; LHS is the same from input.\n";
cerr << "trainMode 1: at training time, one label from RHS is picked as true label; LHS is the bag of the rest RHS labels.\n";
cerr << "trainMode 2: at training time, one label from RHS is picked as LHS; the bag of the rest RHS labels becomes the true label.\n";
cerr << "trainMode 3: at training time, one label from RHS is picked as true label and another label from RHS is picked as LHS.\n";
cerr << "trainMode 4: at training time, the first label from RHS is picked as LHS and the second one picked as true label.\n";
cerr << "trainMode 5: continuous bag of words training.\n";
exit(EXIT_FAILURE);
}
// check for loss type
if (!(loss == "hinge" || loss == "softmax")) {
cerr << "Unsupported loss type: " << loss << endl;
exit(EXIT_FAILURE);
}
// check for similarity type
if (!(similarity == "cosine" || similarity == "dot")) {
cerr << "Unsupported similarity type. Should be either dot or cosine.\n";
exit(EXIT_FAILURE);
}
// check for file format
if (!(fileFormat == "fastText" || fileFormat == "labelDoc")) {
cerr << "Unsupported file format type. Should be either fastText or labelDoc.\n";
exit(EXIT_FAILURE);
}
}
/**
 * Print usage text for "starspace train ..." / "starspace test ..." to
 * stdout, showing the current (default) value of every tunable argument.
 *
 * Fixes several typos in the user-facing help strings:
 * "occurences" -> "occurrences", "an unique" -> "a unique",
 * "for save predictions" -> "for saving predictions",
 * "carry on training" -> "carries on training".
 */
void Args::printHelp() {
  cout << "\n"
       << "\"starspace train ...\" or \"starspace test ...\"\n\n"
       << "The following arguments are mandatory for train: \n"
       << " -trainFile training file path\n"
       << " -model output model file path\n\n"
       << "The following arguments are mandatory for test: \n"
       << " -testFile test file path\n"
       << " -model model file path\n\n"
       << "The following arguments for the dictionary are optional:\n"
       << " -minCount minimal number of word occurrences [" << minCount << "]\n"
       << " -minCountLabel minimal number of label occurrences [" << minCountLabel << "]\n"
       << " -ngrams max length of word ngram [" << ngrams << "]\n"
       << " -bucket number of buckets [" << bucket << "]\n"
       << " -label labels prefix [" << label << "]\n"
       << "\nThe following arguments for training are optional:\n"
       << " -initModel if not empty, it loads a previously trained model in -initModel and carries on training.\n"
       << " -trainMode takes value in [0, 1, 2, 3, 4, 5], see Training Mode Section. [" << trainMode << "]\n"
       << " -fileFormat currently support 'fastText' and 'labelDoc', see File Format Section. [" << fileFormat << "]\n"
       << " -saveEveryEpoch save intermediate models after each epoch [" << saveEveryEpoch << "]\n"
       << " -saveTempModel save intermediate models after each epoch with a unique name including epoch number [" << saveTempModel << "]\n"
       << " -lr learning rate [" << lr << "]\n"
       << " -dim size of embedding vectors [" << dim << "]\n"
       << " -epoch number of epochs [" << epoch << "]\n"
       << " -maxTrainTime max train time (secs) [" << maxTrainTime << "]\n"
       << " -negSearchLimit number of negatives sampled [" << negSearchLimit << "]\n"
       << " -maxNegSamples max number of negatives in a batch update [" << maxNegSamples << "]\n"
       << " -loss loss function {hinge, softmax} [hinge]\n"
       << " -margin margin parameter in hinge loss. It's only effective if hinge loss is used. [" << margin << "]\n"
       << " -similarity takes value in [cosine, dot]. Whether to use cosine or dot product as similarity function in hinge loss.\n"
       << " It's only effective if hinge loss is used. [" << similarity << "]\n"
       << " -adagrad whether to use adagrad in training [" << adagrad << "]\n"
       << " -shareEmb whether to use the same embedding matrix for LHS and RHS. [" << shareEmb << "]\n"
       << " -ws only used in trainMode 5, the size of the context window for word level training. [" << ws << "]\n"
       << " -dropoutLHS dropout probability for LHS features. [" << dropoutLHS << "]\n"
       << " -dropoutRHS dropout probability for RHS features. [" << dropoutRHS << "]\n"
       << " -initRandSd initial values of embeddings are randomly generated from normal distribution with mean=0, standard deviation=initRandSd. [" << initRandSd << "]\n"
       << " -trainWord whether to train word level together with other tasks (for multi-tasking). [" << trainWord << "]\n"
       << " -wordWeight if trainWord is true, wordWeight specifies example weight for word level training examples. [" << wordWeight << "]\n"
       << "\nThe following arguments for test are optional:\n"
       << " -basedoc file path for a set of labels to compare against true label. It is required when -fileFormat='labelDoc'.\n"
       << " In the case -fileFormat='fastText' and -basedoc is not provided, we compare true label with all other labels in the dictionary.\n"
       << " -predictionFile file path for saving predictions. If not empty, top K predictions for each example will be saved.\n"
       << " -K if -predictionFile is not empty, top K predictions for each example will be saved.\n"
       << "\nThe following arguments are optional:\n"
       << " -normalizeText whether to run basic text preprocess for input files [" << normalizeText << "]\n"
       << " -useWeight whether input file contains weights [" << useWeight << "]\n"
       << " -verbose verbosity level [" << verbose << "]\n"
       << " -debug whether it's in debug mode [" << debug << "]\n"
       << " -thread number of threads [" << thread << "]\n"
       << std::endl;
}
/**
 * Dump the currently active training arguments to stdout, one per line.
 * Output text and field order are unchanged from the original.
 */
void Args::printArgs() {
  cout << "Arguments: \n";
  cout << "lr: " << lr << endl;
  cout << "dim: " << dim << endl;
  cout << "epoch: " << epoch << endl;
  cout << "maxTrainTime: " << maxTrainTime << endl;
  cout << "saveEveryEpoch: " << saveEveryEpoch << endl;
  cout << "loss: " << loss << endl;
  cout << "margin: " << margin << endl;
  cout << "similarity: " << similarity << endl;
  cout << "maxNegSamples: " << maxNegSamples << endl;
  cout << "negSearchLimit: " << negSearchLimit << endl;
  cout << "thread: " << thread << endl;
  cout << "minCount: " << minCount << endl;
  cout << "minCountLabel: " << minCountLabel << endl;
  cout << "label: " << label << endl;
  cout << "ngrams: " << ngrams << endl;
  cout << "bucket: " << bucket << endl;
  cout << "adagrad: " << adagrad << endl;
  cout << "trainMode: " << trainMode << endl;
  cout << "fileFormat: " << fileFormat << endl;
  cout << "normalizeText: " << normalizeText << endl;
  cout << "dropoutLHS: " << dropoutLHS << endl;
  cout << "dropoutRHS: " << dropoutRHS << endl;
}
/**
 * Binary serialization of the model-relevant arguments.
 * Field order and widths must stay in lock-step with Args::load();
 * the produced byte stream is identical to the original implementation.
 */
void Args::save(std::ostream& out) {
  auto putInt = [&out](const int& value) {
    out.write((const char*) &value, sizeof(int));
  };
  auto putBool = [&out](const bool& value) {
    out.write((const char*) &value, sizeof(bool));
  };
  auto putString = [&out](const std::string& value) {
    size_t length = value.size();
    out.write((const char*) &length, sizeof(size_t));
    out.write(value.data(), length);
  };
  // dim is declared size_t but is persisted as sizeof(int) bytes in the
  // on-disk format; kept verbatim so existing model files stay readable.
  out.write((char*) &(dim), sizeof(int));
  putInt(epoch);
  // out.write((char*) &(maxTrainTime), sizeof(int));
  putInt(minCount);
  putInt(minCountLabel);
  putInt(maxNegSamples);
  putInt(negSearchLimit);
  putInt(ngrams);
  putInt(bucket);
  putInt(trainMode);
  putBool(shareEmb);
  putBool(useWeight);
  putString(fileFormat);
  putString(similarity);
}
/**
 * Binary deserialization mirroring Args::save(): same field order and
 * the same field widths; reads exactly the bytes save() wrote.
 */
void Args::load(std::istream& in) {
  auto getInt = [&in](int& value) {
    in.read((char*) &value, sizeof(int));
  };
  auto getBool = [&in](bool& value) {
    in.read((char*) &value, sizeof(bool));
  };
  auto getString = [&in](std::string& value) {
    size_t length;
    in.read((char*) &length, sizeof(size_t));
    value.resize(length);
    in.read((char*) &value[0], length);
  };
  // dim is declared size_t but only sizeof(int) bytes are stored on disk;
  // kept verbatim to match save().
  in.read((char*) &(dim), sizeof(int));
  getInt(epoch);
  // in.read((char*) &(maxTrainTime), sizeof(int));
  getInt(minCount);
  getInt(minCountLabel);
  getInt(maxNegSamples);
  getInt(negSearchLimit);
  getInt(ngrams);
  getInt(bucket);
  getInt(trainMode);
  getBool(shareEmb);
  getBool(useWeight);
  getString(fileFormat);
  getString(similarity);
}
}
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
#pragma once
#include <iostream>
#include <string>
namespace starspace {
/**
 * Holds every command-line configurable StarSpace parameter.
 *
 * Fields are filled in by parseArgs() (defaults come from the Args()
 * constructor, defined out of view), displayed via printHelp()/printArgs(),
 * and a model-relevant subset is serialized by save()/load().
 */
class Args {
public:
Args();
std::string trainFile;       // training data path (mandatory for train)
std::string validationFile;  // NOTE(review): declared but not referenced in the visible code
std::string testFile;        // test data path (mandatory for test)
std::string predictionFile;  // where top-K predictions are written (optional)
std::string model;           // model file path (mandatory)
std::string initModel;       // optional previously trained model to continue from
std::string fileFormat;      // input format: "fastText" or "labelDoc"
std::string label;           // prefix that marks label tokens
std::string basedoc;         // label set to rank against at test time
std::string loss;            // loss function: "hinge" or "softmax"
std::string similarity;      // similarity for hinge loss: "cosine" or "dot"
double lr;                   // learning rate
double termLr;               // NOTE(review): declared but not referenced in the visible code
double norm;                 // NOTE(review): declared but not referenced in the visible code
double margin;               // margin parameter for hinge loss
double initRandSd;           // stddev of the normal distribution used to init embeddings
double p;                    // NOTE(review): declared but not referenced in the visible code
double dropoutLHS;           // dropout probability for LHS features
double dropoutRHS;           // dropout probability for RHS features
double wordWeight;           // example weight for word-level examples (with trainWord)
size_t dim;                  // embedding dimension (persisted as sizeof(int) bytes in save/load)
int epoch;                   // number of training epochs
int ws;                      // context window size (trainMode 5 only)
int maxTrainTime;            // max training time in seconds
int thread;                  // number of worker threads
int maxNegSamples;           // max negatives per batch update
int negSearchLimit;          // number of negatives sampled
int minCount;                // min word occurrence count kept in the dictionary
int minCountLabel;           // min label occurrence count kept in the dictionary
int bucket;                  // number of hash buckets for ngrams
int ngrams;                  // max word-ngram length
int trainMode;               // training mode, 0..5 (validated in parseArgs)
int K;                       // top-K predictions written to predictionFile
bool verbose;                // verbosity flag
bool debug;                  // debug-mode flag
bool adagrad;                // use adagrad during training
bool isTrain;                // true for "train", false for "test"
bool normalizeText;          // run basic text normalization on input files
bool saveEveryEpoch;         // save an intermediate model after each epoch
bool saveTempModel;          // save per-epoch models under unique names
bool shareEmb;               // share the embedding matrix between LHS and RHS
bool useWeight;              // input file carries per-token weights
bool trainWord;              // also train word level (multi-tasking)
// Parse argv into the fields above; prints help and exits on invalid input.
void parseArgs(int, char**);
// Print usage text with current values to stdout.
void printHelp();
// Print the active argument values to stdout.
void printArgs();
// Serialize / restore the model-relevant subset of fields (binary format).
void save(std::ostream& out);
void load(std::istream& in);
// Interpret a command-line string as a boolean; defined out of view.
bool isTrue(std::string arg);
};
}
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
#include "normalize.h"
#include <algorithm>
#include <ctype.h>
#include <assert.h>
#include <string>
namespace starspace {
/**
 * In-place, byte-wise normalization of a token.
 *
 * Fix: the original passed raw `char` values to the <ctype.h> classifiers
 * (isdigit/isalpha/tolower). For any non-ASCII byte on a platform where
 * `char` is signed, that passes a negative value, which is undefined
 * behavior for these functions. All calls now cast through unsigned char.
 */
void normalize_text(std::string& str) {
  /*
   * We categorize longer strings into the following buckets:
   *
   * 1. All punctuation-and-numeric. Things in this bucket get
   *    their numbers flattened, to prevent combinatorial explosions.
   *    They might be specific numbers, prices, etc.
   *
   * 2. All letters: case-flattened.
   *
   * 3. Mixed letters and numbers: a product ID? Flatten case and leave
   *    numbers alone.
   *
   * The case-normalization is state-machine-driven.
   */
  bool allNumeric = true;
  bool containsDigits = false;
  for (char c: str) {
    assert(c); // don't shove binary data through this.
    // unsigned char cast: negative char arguments to <ctype.h> are UB.
    const unsigned char uc = static_cast<unsigned char>(c);
    containsDigits |= (isdigit(uc) != 0);
    if (!isascii(uc)) {
      allNumeric = false;
      continue;
    }
    if (!isalpha(uc)) continue;
    allNumeric = false;
  }
  bool flattenCase = true;
  bool flattenNum = allNumeric && containsDigits;
  if (!flattenNum && !flattenCase) return;
  std::transform(str.begin(), str.end(), str.begin(),
                 [&](char c) {
                   const unsigned char uc = static_cast<unsigned char>(c);
                   if (flattenNum && isdigit(uc)) return '0';
                   if (isalpha(uc)) return char(tolower(uc));
                   return c;
                 });
}
}
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
#pragma once
#include <string>
namespace starspace {
// In-place normalization of UTF-8 strings.
// Operates byte-wise: lowercases ASCII letters and, for strings made up
// solely of ASCII digits and punctuation, replaces every digit with '0'
// to collapse prices/IDs into one bucket. Non-ASCII bytes pass through.
extern void normalize_text(std::string& buf);
}
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
#include "utils.h"
namespace starspace {
namespace detail {
// Definition of the per-worker-thread id declared in the utils header.
// `__thread` (GCC/Clang TLS) gives each thread its own zero-initialized
// copy; foreach_line assigns it before processing a partition.
__thread int id;
}
}
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
#pragma once
#include <iostream>
#include <thread>
#include <fstream>
#include <vector>
#include <string>
#include <algorithm>
namespace starspace {
/**
 * Accumulates ranking-evaluation statistics: hit@{1,10,20,50} counts,
 * the sum of ranks, and the number of evaluated examples.
 *
 * Fix: members were uninitialized until clear() was called, so reading a
 * freshly constructed Metrics produced indeterminate values. All members
 * now default to zero; clear() remains for resetting a reused instance.
 * (Note: default member initializers make this a non-aggregate in C++11,
 * so brace-list aggregate initialization would no longer compile.)
 */
struct Metrics {
  float hit1 = 0.0f;   // examples whose true label ranked exactly 1st
  float hit10 = 0.0f;  // ...ranked within the top 10
  float hit20 = 0.0f;  // ...ranked within the top 20
  float hit50 = 0.0f;  // ...ranked within the top 50
  float rank = 0.0f;   // sum of ranks (becomes the mean rank after average())
  int32_t count = 0;   // number of examples accumulated

  // Reset every counter to zero so the instance can be reused.
  void clear() {
    hit1 = 0;
    hit10 = 0;
    hit20 = 0;
    hit50 = 0;
    rank = 0;
    count = 0;
  }

  // Merge another partial Metrics (e.g. from another thread) into this one.
  void add(const Metrics& b) {
    hit1 += b.hit1;
    hit10 += b.hit10;
    hit20 += b.hit20;
    hit50 += b.hit50;
    rank += b.rank;
    count += b.count;
  }

  // Turn the summed counters into per-example averages.
  // No-op when count == 0 to avoid division by zero.
  void average() {
    if (count == 0) {
      return ;
    }
    hit1 /= count;
    hit10 /= count;
    hit20 /= count;
    hit50 /= count;
    rank /= count;
  }

  // Print the (averaged) metrics to stdout.
  void print() {
    std::cout << "Evaluation Metrics : \n"
              << "hit@1: " << hit1
              << " hit@10: " << hit10
              << " hit@20: " << hit20
              << " hit@50: " << hit50
              << " mean ranks : " << rank
              << " Total examples : " << count << "\n";
  }

  // Record one example whose true label was ranked at cur_rank (1-based).
  void update(int cur_rank) {
    if (cur_rank == 1) { hit1++; }
    if (cur_rank <= 10) { hit10++; }
    if (cur_rank <= 20) { hit20++; }
    if (cur_rank <= 50) { hit50++; }
    rank += cur_rank;
    count++;
  }
};
namespace detail {
// Per-worker-thread id assigned by foreach_line; the `__thread` TLS
// variable itself is defined in the accompanying .cpp file.
extern __thread int id;
}
namespace {
// Returns the id foreach_line assigned to the calling worker thread
// (0 for threads that never set it, since __thread ints start zeroed).
// NOTE(review): anonymous namespace in a header gives each translation
// unit its own copy of this function; presumably intentional for this
// inline helper, but worth confirming.
inline int getThreadID() {
return detail::id;
}
}
namespace {
// Rewind a stream to an absolute position. clear() must run before the
// seek so that a stream stuck in eof/fail state can still be repositioned.
template<typename Stream>
void reset(Stream& strm, std::streampos where) {
  strm.clear();
  strm.seekg(where, std::ios_base::beg);
}

// Current read position of the stream (thin wrapper kept for symmetry
// with reset()).
template<typename Stream>
std::streampos tellg(Stream& strm) {
  return strm.tellg();
}
}
// Apply a closure pointwise to every line of a file.
//
// The file is split into `numThreads` byte-wise partitions; each worker
// thread opens its own stream, seeks to its partition start, and processes
// lines until it crosses the partition end. Throws std::runtime_error if
// the file cannot be opened.
//
// Fixes: removed the unused local (`auto pos = tellg(f)`) inside the
// file-length helper, and renamed that helper's parameter, which shadowed
// the closure parameter `f`.
template<typename String=std::string,
         typename Lambda>
void foreach_line(const String& fname,
                  Lambda f,
                  int numThreads = 1) {
  using namespace std;
  // Length of the file in bytes; leaves the stream positioned at EOF
  // (callers re-seek with reset() before reading).
  auto filelen = [](ifstream& stream) {
    stream.seekg(0, ios_base::end);
    return tellg(stream);
  };
  ifstream ifs(fname);
  if (!ifs.good()) {
    throw runtime_error(string("error opening ") + fname);
  }
  auto len = filelen(ifs);
  // partitions[i],partitions[i+1] will be the bytewise boundaries for the i'th
  // thread.
  std::vector<off_t> partitions(numThreads + 1);
  partitions[0] = 0;
  partitions[numThreads] = len;
  // Seek to bytewise partition boundaries, and read one line forward so
  // every partition starts at the beginning of a line.
  string unused;
  for (int i = 1; i < numThreads; i++) {
    reset(ifs, (len / numThreads) * i);
    getline(ifs, unused);
    partitions[i] = tellg(ifs);
  }
  // It's possible that the ranges in partitions overlap; consider,
  // e.g., a machine with 100 hardware threads and only 99 lines
  // in the file. In this case, we'll do some excess work, so we ask
  // that f() be idempotent.
  vector<thread> threads;
  for (int i = 0; i < numThreads; i++) {
    threads.emplace_back([i, f, &fname, &partitions] {
      detail::id = i; // publish this worker's id for getThreadID()
      // Get our own seek pointer.
      ifstream ifs2(fname);
      ifs2.seekg(partitions[i]);
      string line;
      while (tellg(ifs2) < partitions[i + 1] && getline(ifs2, line)) {
        // We don't know the line number. Super-bummer.
        f(line);
      }
    });
  }
  for (auto &t: threads) {
    t.join();
  }
}
} // namespace
Version: 1.0
RestoreWorkspace: Default
SaveWorkspace: Default
AlwaysSaveHistory: Default
EnableCodeIndexing: Yes
UseSpacesForTab: Yes
NumSpacesForTab: 2
Encoding: UTF-8
RnwWeave: Sweave
LaTeX: pdfLaTeX
AutoAppendNewline: Yes
StripTrailingWhitespace: Yes
BuildType: Package
PackageInstallArgs: --no-multiarch --with-keep.source
PackageRoxygenize: rd,collate,namespace
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment