前幾天有網友回覆 四則運算解析器 那篇,回頭瞄了一下舊程式碼剛好讓我得到了一個靈感,所以寫了這篇「函數解析器」。

我過去曾用 C++03 實作過一些小型語言的編譯器、直譯器,使用都是比較傳統的方法,也就是設計一個 AST 節點的基礎類別,再特化出各種不同類型的 AST 節點。這個寫法非常的囉唆,許多程式碼都是為了滿足靜態型別語言的規範,而不是實現真正的功能,相較之下 python、javascript 之類的語言可以用精簡許多的程式碼完成同樣的事情。

我最初的想法是,既然 C++11 有了 std::function,那麼就不需要透過基礎類別來提供共同界面,接著我又想到了可以用 std::bind 來連結子節點。最後靈光乍現,只要有 lambda,根本就沒有必要用 std::bind。對了,這個程式使用了 std::regex,最好使用最新的編譯器,已知在 GCC 4.8 會有錯誤,至於 GCC 5 以後和近幾版的 clang 應該都沒問題。

這個函數解析器的特點:

  • 可以處理運算子優現順序和括號。
  • 支援 log、sin、sqrt 等數學函數。
  • 接受多參數。

簡單介紹一下使用方法:

auto foo = parseFunction({"x", "y"}, "sqrt(x*x + y*y)");
cout << foo({3, 4}) << "\n";
// 5

auto bar = parseFunction({"r", "pi"}, "r * r * pi");
cout << bar({10, 3.14159}) << endl;
// 314.159

這個程式能透過 exceptions 回報某些語法錯誤,並指出發生位置,雖然仍然不完善。

auto ill = parseFunction({"x"}, "100sin(x)");
// 100sin(x)
//    ^

auto ill = parseFunction({"x"}, "x+k");
// x+k
//   ^

auto ill = parseFunction({"x", "y"}, "30*(x+y");
// 30*(x+y
//        ^

我一度猶豫要不要加個正規的 lexer,這樣錯誤處理會更方便且更精確,只是這樣一來程式碼勢必會膨脹許多。考慮到這只是一篇發表在部落格的概念演示,我想還是點到為止。(2016-08-28: 改用 std::sregex_token_iterator 改寫,不過仍然稱不上是完善的 lexer,而且回報錯誤位置變得有點 tricky)

對於直譯器技術有興趣的網友,可以從這篇了解 AST traversal 直譯器的基本概念。至於更進階的 bytecode 直譯器(register or stack based)以後有空或許我會寫幾篇文章來介紹。


#include <algorithm>
#include <regex>
#include <map>
#include <vector>
#include <iostream>
#include <cmath>

using namespace std;

using ArgList = std::vector<double>;
using Func = function<double(const ArgList&)>;
using TokenIter = std::sregex_token_iterator;
using VarTable = map<string, int>;

const map<string, double(*)(double)> funcTable = {
    {"log", log}, {"sqrt", sqrt}, {"sin", sin}, {"cos", cos}
};

Func parseExpr(const VarTable& varTable, TokenIter& it, const TokenIter& tokEnd);

Func parseElem(const VarTable& varTable, TokenIter& it, const TokenIter& tokEnd) {
    if (it != tokEnd) {
        string token = *it;
        auto fn = funcTable.find(token);
        if (fn != funcTable.end()) {
            auto f = fn->second;
            auto sub = parseElem(varTable, ++it, tokEnd);
            return [f, sub] (const ArgList& a) { return f(sub(a)); };
        }
        auto var = varTable.find(token);
        if (var != varTable.end()) {
            auto i = var->second;
            ++it;
            return [i] (const ArgList& a) { return a[i]; };
        }
        if (token == "(") {
            auto sub = parseExpr(varTable, ++it, tokEnd);
            if (it != tokEnd && *it == ")") {
                ++it;
                return sub;
            }
        } else if (token == "-") {
            auto sub = parseElem(varTable, ++it, tokEnd);
            return [sub] (const ArgList& a) { return -sub(a); };
        } else {
            double c = stod(token);
            ++it;
            return [c] (const ArgList& a) {  return c; };
        }
    }
    throw runtime_error("syntax error");
}

Func parseProd(const VarTable& varTable, TokenIter& it, const TokenIter& tokEnd) {
    Func lhs = parseElem(varTable, it, tokEnd);
    while (it != tokEnd && (*it == "*" || *it == "/")) {
        string op = *it++;
        Func rhs = parseElem(varTable, it, tokEnd);
        lhs = (op == "*") ?
                Func([lhs, rhs](const ArgList& a) { return lhs(a) * rhs(a); }) :
                Func([lhs, rhs](const ArgList& a) { return lhs(a) / rhs(a); }) ;
    }
    return lhs;
}

Func parseExpr(const VarTable& varTable, TokenIter& it, const TokenIter& tokEnd) {
    Func lhs = parseProd(varTable, it, tokEnd);
    while (it != tokEnd && (*it == "+" || *it == "-")) {
        string op = *it++;
        Func rhs = parseProd(varTable, it, tokEnd);
        lhs = (op == "+") ?
                Func([lhs, rhs](const ArgList& a) { return lhs(a) + rhs(a); }) :
                Func([lhs, rhs](const ArgList& a) { return lhs(a) - rhs(a); }) ;
    }
    return lhs;
}

Func parseFunction(const vector<string>& varList, string src) {
    VarTable varTable;
    for (auto i = 0; i < varList.size(); ++i)
        varTable[varList[i]] = i;

    src.erase(remove_if(src.begin(), src.end(), ::isspace), src.end());
    src.push_back(' ');
    regex rx("[[:alpha:]]+|[[:digit:].]+|[^[:alnum:].]");
    auto it = TokenIter(src.begin(), src.end(), rx);
    try {
        auto expr = parseExpr(varTable, it, TokenIter());
        if (it !=  TokenIter() && *it != ' ')
            throw runtime_error("syntax error");
        return expr;
    } catch (exception& ex) {
        auto pos = (it !=  TokenIter()) ? it->first - src.begin() : 0;
        throw runtime_error(string("syntax error:\n    ") + src + "\n"
                + string(pos + 4, ' ') + '^');
    }
}

int main() {
    auto f1 = parseFunction({"x", "y"}, "1.5 + sqrt(x*x + y*y) / 2");
    auto f2 = [](double x, double y) { return 1.5 + sqrt(x*x + y*y) / 2; };
    cout << f1({3, 4}) << "\t"
         << f2( 3, 4 ) << "\n";
    cout << f1({3.51, 2.18}) << "\t"
         << f2( 3.51, 2.18 ) << "\n";
    cout << f1({-12.34, 56.789}) << "\t"
         << f2( -12.34, 56.789 ) << "\n";
    return 0;
}
arrow
arrow
    文章標籤
    c++11 c++14
    全站熱搜

    novus 發表在 痞客邦 留言(1) 人氣()