前幾天有網友回覆 四則運算解析器 那篇,回頭瞄了一下舊程式碼剛好讓我得到了一個靈感,所以寫了這篇「函數解析器」。
我過去曾用 C++03 實作過一些小型語言的編譯器、直譯器,使用都是比較傳統的方法,也就是設計一個 AST 節點的基礎類別,再特化出各種不同類型的 AST 節點。這個寫法非常的囉唆,許多程式碼都是為了滿足靜態型別語言的規範,而不是實現真正的功能,相較之下 python、javascript 之類的語言可以用精簡許多的程式碼完成同樣的事情。
我最初的想法是,既然 C++11 有了 std::function,那麼就不需要透過基礎類別來提供共同界面,接著我又想到了可以用 std::bind 來連結子節點。最後靈光乍現,只要有 lambda,根本就沒有必要用 std::bind。對了,這個程式使用了 std::regex,最好使用最新的編譯器,已知在 GCC 4.8 會有錯誤,至於 GCC 5 以後和近幾版的 clang 應該都沒問題。
這個函數解析器的特點:
- 可以處理運算子優現順序和括號。
- 支援 log、sin、sqrt 等數學函數。
- 接受多參數。
簡單介紹一下使用方法:
auto foo = parseFunction({"x", "y"}, "sqrt(x*x + y*y)");
cout << foo({3, 4}) << "\n";
// 5
auto bar = parseFunction({"r", "pi"}, "r * r * pi");
cout << bar({10, 3.14159}) << endl;
// 314.159
這個程式能透過 exceptions 回報某些語法錯誤,並指出發生位置,雖然仍然不完善。
auto ill = parseFunction({"x"}, "100sin(x)");
// 100sin(x)
// ^
auto ill = parseFunction({"x"}, "x+k");
// x+k
// ^
auto ill = parseFunction({"x", "y"}, "30*(x+y");
// 30*(x+y
// ^
我一度猶豫要不要加個正規的 lexer,這樣錯誤處理會更方便且更精確,只是這樣一來程式碼勢必會膨脹許多。考慮到這只是一篇發表在部落格的概念演示,我想還是點到為止。(2016-08-28: 改用 std::sregex_token_iterator 改寫,不過仍然稱不上是完善的 lexer,而且回報錯誤位置變得有點 tricky)
對於直譯器技術有興趣的網友,可以從這篇了解 AST traversal 直譯器的基本概念。至於更進階的 bytecode 直譯器(register or stack based)以後有空或許我會寫幾篇文章來介紹。
#include <algorithm>
#include <regex>
#include <map>
#include <vector>
#include <iostream>
#include <cmath>
using namespace std;
using ArgList = std::vector<double>;
using Func = function<double(const ArgList&)>;
using TokenIter = std::sregex_token_iterator;
using VarTable = map<string, int>;
const map<string, double(*)(double)> funcTable = {
{"log", log}, {"sqrt", sqrt}, {"sin", sin}, {"cos", cos}
};
Func parseExpr(const VarTable& varTable, TokenIter& it, const TokenIter& tokEnd);
Func parseElem(const VarTable& varTable, TokenIter& it, const TokenIter& tokEnd) {
if (it != tokEnd) {
string token = *it;
auto fn = funcTable.find(token);
if (fn != funcTable.end()) {
auto f = fn->second;
auto sub = parseElem(varTable, ++it, tokEnd);
return [f, sub] (const ArgList& a) { return f(sub(a)); };
}
auto var = varTable.find(token);
if (var != varTable.end()) {
auto i = var->second;
++it;
return [i] (const ArgList& a) { return a[i]; };
}
if (token == "(") {
auto sub = parseExpr(varTable, ++it, tokEnd);
if (it != tokEnd && *it == ")") {
++it;
return sub;
}
} else if (token == "-") {
auto sub = parseElem(varTable, ++it, tokEnd);
return [sub] (const ArgList& a) { return -sub(a); };
} else {
double c = stod(token);
++it;
return [c] (const ArgList& a) { return c; };
}
}
throw runtime_error("syntax error");
}
Func parseProd(const VarTable& varTable, TokenIter& it, const TokenIter& tokEnd) {
Func lhs = parseElem(varTable, it, tokEnd);
while (it != tokEnd && (*it == "*" || *it == "/")) {
string op = *it++;
Func rhs = parseElem(varTable, it, tokEnd);
lhs = (op == "*") ?
Func([lhs, rhs](const ArgList& a) { return lhs(a) * rhs(a); }) :
Func([lhs, rhs](const ArgList& a) { return lhs(a) / rhs(a); }) ;
}
return lhs;
}
Func parseExpr(const VarTable& varTable, TokenIter& it, const TokenIter& tokEnd) {
Func lhs = parseProd(varTable, it, tokEnd);
while (it != tokEnd && (*it == "+" || *it == "-")) {
string op = *it++;
Func rhs = parseProd(varTable, it, tokEnd);
lhs = (op == "+") ?
Func([lhs, rhs](const ArgList& a) { return lhs(a) + rhs(a); }) :
Func([lhs, rhs](const ArgList& a) { return lhs(a) - rhs(a); }) ;
}
return lhs;
}
Func parseFunction(const vector<string>& varList, string src) {
VarTable varTable;
for (auto i = 0; i < varList.size(); ++i)
varTable[varList[i]] = i;
src.erase(remove_if(src.begin(), src.end(), ::isspace), src.end());
src.push_back(' ');
regex rx("[[:alpha:]]+|[[:digit:].]+|[^[:alnum:].]");
auto it = TokenIter(src.begin(), src.end(), rx);
try {
auto expr = parseExpr(varTable, it, TokenIter());
if (it != TokenIter() && *it != ' ')
throw runtime_error("syntax error");
return expr;
} catch (exception& ex) {
auto pos = (it != TokenIter()) ? it->first - src.begin() : 0;
throw runtime_error(string("syntax error:\n ") + src + "\n"
+ string(pos + 4, ' ') + '^');
}
}
int main() {
auto f1 = parseFunction({"x", "y"}, "1.5 + sqrt(x*x + y*y) / 2");
auto f2 = [](double x, double y) { return 1.5 + sqrt(x*x + y*y) / 2; };
cout << f1({3, 4}) << "\t"
<< f2( 3, 4 ) << "\n";
cout << f1({3.51, 2.18}) << "\t"
<< f2( 3.51, 2.18 ) << "\n";
cout << f1({-12.34, 56.789}) << "\t"
<< f2( -12.34, 56.789 ) << "\n";
return 0;
}

太神了!!
後來試了一下發覺 bug 不少 然後發現可以用 sregex_token_iterator 做 lexer