前幾天有網友回覆 四則運算解析器 那篇,回頭瞄了一下舊程式碼剛好讓我得到了一個靈感,所以寫了這篇「函數解析器」。
我過去曾用 C++03 實作過一些小型語言的編譯器、直譯器,使用都是比較傳統的方法,也就是設計一個 AST 節點的基礎類別,再特化出各種不同類型的 AST 節點。這個寫法非常的囉唆,許多程式碼都是為了滿足靜態型別語言的規範,而不是實現真正的功能,相較之下 python、javascript 之類的語言可以用精簡許多的程式碼完成同樣的事情。
我最初的想法是,既然 C++11 有了 std::function,那麼就不需要透過基礎類別來提供共同界面,接著我又想到了可以用 std::bind 來連結子節點。最後靈光乍現,只要有 lambda,根本就沒有必要用 std::bind。對了,這個程式使用了 std::regex,最好使用最新的編譯器,已知在 GCC 4.8 會有錯誤,至於 GCC 5 以後和近幾版的 clang 應該都沒問題。
這個函數解析器的特點:
- 可以處理運算子優現順序和括號。
- 支援 log、sin、sqrt 等數學函數。
- 接受多參數。
簡單介紹一下使用方法:
auto foo = parseFunction({"x", "y"}, "sqrt(x*x + y*y)"); cout << foo({3, 4}) << "\n"; // 5 auto bar = parseFunction({"r", "pi"}, "r * r * pi"); cout << bar({10, 3.14159}) << endl; // 314.159
這個程式能透過 exceptions 回報某些語法錯誤,並指出發生位置,雖然仍然不完善。
auto ill = parseFunction({"x"}, "100sin(x)"); // 100sin(x) // ^ auto ill = parseFunction({"x"}, "x+k"); // x+k // ^ auto ill = parseFunction({"x", "y"}, "30*(x+y"); // 30*(x+y // ^
我一度猶豫要不要加個正規的 lexer,這樣錯誤處理會更方便且更精確,只是這樣一來程式碼勢必會膨脹許多。考慮到這只是一篇發表在部落格的概念演示,我想還是點到為止。(2016-08-28: 改用 std::sregex_token_iterator 改寫,不過仍然稱不上是完善的 lexer,而且回報錯誤位置變得有點 tricky)
對於直譯器技術有興趣的網友,可以從這篇了解 AST traversal 直譯器的基本概念。至於更進階的 bytecode 直譯器(register or stack based)以後有空或許我會寫幾篇文章來介紹。
#include <algorithm> #include <regex> #include <map> #include <vector> #include <iostream> #include <cmath> using namespace std; using ArgList = std::vector<double>; using Func = function<double(const ArgList&)>; using TokenIter = std::sregex_token_iterator; using VarTable = map<string, int>; const map<string, double(*)(double)> funcTable = { {"log", log}, {"sqrt", sqrt}, {"sin", sin}, {"cos", cos} }; Func parseExpr(const VarTable& varTable, TokenIter& it, const TokenIter& tokEnd); Func parseElem(const VarTable& varTable, TokenIter& it, const TokenIter& tokEnd) { if (it != tokEnd) { string token = *it; auto fn = funcTable.find(token); if (fn != funcTable.end()) { auto f = fn->second; auto sub = parseElem(varTable, ++it, tokEnd); return [f, sub] (const ArgList& a) { return f(sub(a)); }; } auto var = varTable.find(token); if (var != varTable.end()) { auto i = var->second; ++it; return [i] (const ArgList& a) { return a[i]; }; } if (token == "(") { auto sub = parseExpr(varTable, ++it, tokEnd); if (it != tokEnd && *it == ")") { ++it; return sub; } } else if (token == "-") { auto sub = parseElem(varTable, ++it, tokEnd); return [sub] (const ArgList& a) { return -sub(a); }; } else { double c = stod(token); ++it; return [c] (const ArgList& a) { return c; }; } } throw runtime_error("syntax error"); } Func parseProd(const VarTable& varTable, TokenIter& it, const TokenIter& tokEnd) { Func lhs = parseElem(varTable, it, tokEnd); while (it != tokEnd && (*it == "*" || *it == "/")) { string op = *it++; Func rhs = parseElem(varTable, it, tokEnd); lhs = (op == "*") ? Func([lhs, rhs](const ArgList& a) { return lhs(a) * rhs(a); }) : Func([lhs, rhs](const ArgList& a) { return lhs(a) / rhs(a); }) ; } return lhs; } Func parseExpr(const VarTable& varTable, TokenIter& it, const TokenIter& tokEnd) { Func lhs = parseProd(varTable, it, tokEnd); while (it != tokEnd && (*it == "+" || *it == "-")) { string op = *it++; Func rhs = parseProd(varTable, it, tokEnd); lhs = (op == "+") ? Func([lhs, rhs](const ArgList& a) { return lhs(a) + rhs(a); }) : Func([lhs, rhs](const ArgList& a) { return lhs(a) - rhs(a); }) ; } return lhs; } Func parseFunction(const vector<string>& varList, string src) { VarTable varTable; for (auto i = 0; i < varList.size(); ++i) varTable[varList[i]] = i; src.erase(remove_if(src.begin(), src.end(), ::isspace), src.end()); src.push_back(' '); regex rx("[[:alpha:]]+|[[:digit:].]+|[^[:alnum:].]"); auto it = TokenIter(src.begin(), src.end(), rx); try { auto expr = parseExpr(varTable, it, TokenIter()); if (it != TokenIter() && *it != ' ') throw runtime_error("syntax error"); return expr; } catch (exception& ex) { auto pos = (it != TokenIter()) ? it->first - src.begin() : 0; throw runtime_error(string("syntax error:\n ") + src + "\n" + string(pos + 4, ' ') + '^'); } } int main() { auto f1 = parseFunction({"x", "y"}, "1.5 + sqrt(x*x + y*y) / 2"); auto f2 = [](double x, double y) { return 1.5 + sqrt(x*x + y*y) / 2; }; cout << f1({3, 4}) << "\t" << f2( 3, 4 ) << "\n"; cout << f1({3.51, 2.18}) << "\t" << f2( 3.51, 2.18 ) << "\n"; cout << f1({-12.34, 56.789}) << "\t" << f2( -12.34, 56.789 ) << "\n"; return 0; }
留言列表