跳至主要內容

用户自定义函数

鲁老师大约 10 分钟Flink

提示

本教程已出版为《Flink原理与实践》,感兴趣的读者请在各大电商平台购买!




System Function给我们提供了大量内置功能,但对于一些特定领域或特定场景,System Function还远远不够,Flink提供了用户自定义函数功能,开发者可以实现一些特定的需求。用户自定义函数需要注册到Catalog中,因此这类函数又被称为Catalog Function。Catalog Function大大增强了Flink SQL的表达能力。

注册函数

在使用一个函数前,一般需要将这个函数注册到Catalog中。注册时需要调用TableEnvironment中的registerFunction方法。每个TableEnvironment都会有一个成员FunctionCatalogFunctionCatalog中存储了函数的定义,当注册函数时,实际上是将这个函数名和对应的实现写入到FunctionCatalog中。以注册一个ScalarFunction为例,它在源码中如下:

FunctionCatalog functionCatalog = ...

/**
  * 注册一个ScalaFunction
  * name: 函数名
  * function: 一个自定义的ScalaFunction
  */
public void registerFunction(String name, ScalarFunction function) {
		functionCatalog.registerTempSystemScalarFunction(
			name,
			function);
}

在Flink提供的System Function中,我们已经提到,内置的System Function提供了包括数学、比较、字符串、聚合等常用功能,如果这些内置的System Function无法满足我们的需求,我们可以使用Java、Scala和Python语言自定义一个函数。接下来我们将将详细讲解如何自定义函数以及如何使用函数。

标量函数

标量函数(Scalar Function)接收零个、一个或者多个输入,输出一个单值标量。这里以处理经纬度为例来展示如何自定义Scala Function。

当前,大量的应用极度依赖地理信息(Geographic Information):打车软件需要用户用户定位起点和终点、外卖平台需要确定用户送餐地点、运动类APP会记录用户的活动轨迹等。我们一般使用精度(Longitude)和纬度(Latitude)来标记一个地点。经纬度作为原始数据很难直接拿来分析,需要做一些转化,而Table API & SQL中没有相应的函数,因此需要我们自己来实现。

如果想自定义函数,我们需要继承org.apache.flink.table.functions.ScalarFunction类,实现eval方法。这与第四章介绍的DataStream API中算子自定义有异曲同工之处。假设我们需要判断一个经纬度数据是否在北京四环以内,可以使用Java实现下面的函数:

public class IsInFourRing extends ScalarFunction {

    // 北京四环经纬度范围
    private static double LON_EAST = 116.48;
    private static double LON_WEST = 116.27;
    private static double LAT_NORTH = 39.988;
    private static double LAT_SOUTH = 39.83;

    // 判断输入的经纬度是否在四环内
    public boolean eval(double lon, double lat) {
        return !(lon > LON_EAST || lon < LON_WEST) &&
                !(lat > LAT_NORTH || lat < LAT_SOUTH);
    }
}

在这个实现中,eval方法接收两个double类型的输入,对数据进行处理,生成一个boolean类型的输出。整个类中最重要的地方是eval方法,它决定了这个自定义函数的内在逻辑。自定义好函数之后,我们还需要用registerFunction方法将这个函数注册到Catalog中,并为之起名为IsInFourRing,这样就可以在SQL语句中使用IsInFourRing的这个名字进行计算了。

List<Tuple4<Long, Double, Double, Timestamp>> geoList = new ArrayList<>();
geoList.add(Tuple4.of(1L, 116.2775, 39.91132, Timestamp.valueOf("2020-03-06 00:00:00")));
geoList.add(Tuple4.of(2L, 116.44095, 39.88319, Timestamp.valueOf("2020-03-06 00:00:01")));
geoList.add(Tuple4.of(3L, 116.25965, 39.90478, Timestamp.valueOf("2020-03-06 00:00:02")));
geoList.add(Tuple4.of(4L, 116.27054, 39.87869, Timestamp.valueOf("2020-03-06 00:00:03")));

DataStream<Tuple4<Long, Double, Double, Timestamp>> geoStream = env
            .fromCollection(geoList)
            .assignTimestampsAndWatermarks(new AscendingTimestampExtractor<Tuple4<Long, Double, Double, Timestamp>>() {
                @Override
                public long extractAscendingTimestamp(Tuple4<Long, Double, Double, Timestamp> element) {
                    return element.f3.getTime();
                }
            });

// 创建表
Table geoTable = tEnv.fromDataStream(geoStream, "id, long, alt, ts.rowtime, proc.proctime");
tEnv.createTemporaryView("geo", geoTable);

// 注册函数到Catalog中,指定名字为IsInFourRing
tEnv.registerFunction("IsInFourRing", new IsInFourRing());

// 在SQL语句中使用IsInFourRing函数
Table inFourRingTab = tEnv.sqlQuery("SELECT id FROM geo WHERE IsInFourRing(long, alt)");

我们也可以利用编程语言的重载特性,针对不同类型的输入设计不同的函数。假如经纬度参数以float或者String形式传入,为了适应这些输入,可以实现多个eval方法,让编译器帮忙做重载:

public boolean eval(double lon, double lat) {
    return !(lon > LON_EAST || lon < LON_WEST) &&
      			!(lat > LAT_NORTH || lat < LAT_SOUTH);
}

public boolean eval(float lon, float lat) {
    return !(lon > LON_EAST || lon < LON_WEST) &&
            !(lat > LAT_NORTH || lat < LAT_SOUTH);
}

public boolean eval(String lonStr, String latStr) {
    double lon = Double.parseDouble(lonStr);
    double lat = Double.parseDouble(latStr);
    return !(lon > LON_EAST || lon < LON_WEST) &&
            !(lat > LAT_NORTH || lat < LAT_SOUTH);
}

eval方法的输入和输出类型决定了ScalarFunction的输入输出类型。在具体的执行过程中,Flink的类型系统会自动推测输入和输出类型,一些无法被自动推测的类型可以使用DataTypeHint来提示Flink使用哪种输入输出类型。下面的代码接收两个Timestamp作为输入,返回两个时间戳之间的差,用DataTypeHint来提示将返回结果转化为BIGINT类型。

public class TimeDiff extends ScalarFunction {

    public @DataTypeHint("BIGINT") long eval(Timestamp first, Timestamp second) {
        return java.time.Duration.between(first.toInstant(), second.toInstant()).toMillis();
    }
}

DataTypeHint一般可以满足绝大多数的需求,如果类型仍然复杂,开发者可以自己重写UserDefinedFunction#getTypeInference(DataTypeFactory)方法,返回合适的类型。

表函数

另一种常见的用户自定义函数为表函数(Table Function)。Table Function能够接收零到多个标量输入,与Scalar Function不同的是,Table Function输出零到多行,每行数据一到多列。从这些特征来看,Table Function更像是一个表,一般出现在FROM之后。我们在Temporal Table Join中提到的Temporal Table就是一种Table Function。

为了定义Table Function,我们需要继承org.apache.flink.table.functions.TableFunction类,然后实现eval方法,这与Scalar Function几乎一致。同样,我们可以利用重载,实现一到多个eval方法。与Scala Function中只输出一个标量不同,Table Function可以输出零到多行,eval方法里使用collect方法将结果输出,输出的数据类型由TableFunction<T>中的泛型T决定。

下面的代码将字符串输入按照#切分,输出零到多行,输出类型为String

public class TableFunc extends TableFunction<String> {

    // 按#切分字符串,输出零到多行
    public void eval(String str) {
        if (str.contains("#")) {
            String[] arr = str.split("#");
            for (String i: arr) {
                collect(i);
            }
        }
    }
}

在主逻辑中,我们需要使用registerFunction方法注册函数,并指定一个名字。在SQL语句中,使用LATERAL TABLE(<TableFunctionName>)来调用这个Table Function。

List<Tuple4<Integer, Long, String, Timestamp>> list = new ArrayList<>();
list.add(Tuple4.of(1, 1L, "Jack#22", Timestamp.valueOf("2020-03-06 00:00:00")));
list.add(Tuple4.of(2, 2L, "John#19", Timestamp.valueOf("2020-03-06 00:00:01")));
list.add(Tuple4.of(3, 3L, "nosharp", Timestamp.valueOf("2020-03-06 00:00:03")));

DataStream<Tuple4<Integer, Long, String, Timestamp>> stream = env
            .fromCollection(list)
            .assignTimestampsAndWatermarks(new AscendingTimestampExtractor<Tuple4<Integer, Long, String, Timestamp>>() {
                @Override
                public long extractAscendingTimestamp(Tuple4<Integer, Long, String, Timestamp> element) {
                    return element.f3.getTime();
                }
            });
// 获取Table
Table table = tEnv.fromDataStream(stream, "id, long, str, ts.rowtime");
tEnv.createTemporaryView("input_table", table);

// 注册函数到Catalog中,指定名字为Func
tEnv.registerFunction("Func", new TableFunc());

// input_table与LATERAL TABLE(Func(str))进行JOIN
Table tableFunc = tEnv.sqlQuery("SELECT id, s FROM input_table, LATERAL TABLE(Func(str)) AS T(s)");

在这个例子中,LATERAL TABLE(Func(str))接受input_table中字段str作为输入,被命名为一个新表,名为TT中有一个字段s,s是我们刚刚自定义的TableFunc的输出。本例中,input_tableLATERAL TABLE(Func(str))之间使用逗号,隔开,实际上这两个表是按照CROSS JOIN方式连接起来的,或者说,这两个表在做笛卡尔积,这个SQL语句返回值为:

1,22
1,Jack
2,19
2,John

我们也可以使用其他类型的JOIN,比如LEFT JOIN

// input_table与LATERAL TABLE(Func(str))进行LEFT JOIN
Table joinTableFunc = tEnv.sqlQuery("SELECT id, s FROM input_table LEFT JOIN LATERAL TABLE(Func(str)) AS T(s) ON TRUE");

ON TRUE条件表示所有左侧表中的数据都与右侧进行Join,因此结果中多出了一行3,null

1,22
1,Jack
2,19
2,John
3,null

聚合函数

在System Function中我们曾介绍了聚合函数,聚合函数一般将多行数据进行聚合,输出一个标量。常用的聚合函数有COUNTSUM等。对于一些特定问题,这些内置函数可能无法满足需求,在Flink SQL中,用户可以对聚合函数进行用户自定义,这种函数被称为用户自定义聚合函数(User-Defined Aggregate Function)。

假设我们的表中有下列字段:id、数值v、权重w,我们对id进行GROUP BY,计算v的加权平均值。计算的过程如下表所示。

用户自定义聚合函数:求加权平均
用户自定义聚合函数:求加权平均

下面的代码实现了一个加权平均函数WeightedAvg,这个函数接收两个Long类型的输入,返回一个Double类型的输出。计算过程基于累加器WeightedAvgAccum,它记录了当前加权和sum以及权重weight

import org.apache.flink.table.functions.AggregateFunction;
import java.util.Iterator;

/**
 * 加权平均函数
 */
public class WeightedAvg extends AggregateFunction<Double, WeightedAvg.WeightedAvgAccum> {

    @Override
    public WeightedAvgAccum createAccumulator() {
        return new WeightedAvgAccum();
    }

    // 需要物化输出时,getValue方法会被调用
    @Override
    public Double getValue(WeightedAvgAccum acc) {
        if (acc.weight == 0) {
            return null;
        } else {
            return (double) acc.sum / acc.weight;
        }
    }

    // 新数据到达时,更新ACC
    public void accumulate(WeightedAvgAccum acc, long iValue, long iWeight) {
        acc.sum += iValue * iWeight;
        acc.weight += iWeight;
    }

    // 用于BOUNDED OVER WINDOW,将较早的数据剔除
    public void retract(WeightedAvgAccum acc, long iValue, long iWeight) {
        acc.sum -= iValue * iWeight;
        acc.weight -= iWeight;
    }

    // 将多个ACC合并为一个ACC
    public void merge(WeightedAvgAccum acc, Iterable<WeightedAvgAccum> it) {
        Iterator<WeightedAvgAccum> iter = it.iterator();
        while (iter.hasNext()) {
            WeightedAvgAccum a = iter.next();
            acc.weight += a.weight;
            acc.sum += a.sum;
        }
    }

    // 重置ACC
    public void resetAccumulator(WeightedAvgAccum acc) {
        acc.weight = 0l;
        acc.sum = 0l;
    }

    /**
     * 累加器 Accumulator
     * sum: 和
     * weight: 权重
     */
    public static class WeightedAvgAccum {
        public long sum = 0;
        public long weight = 0;
    }
}

从这个例子我们可以看到,自定义聚合函数时,我们需要继承org.apache.flink.table.functions.AggregateFunction类。注意,这个类与DataStream API的窗口算子中所介绍的AggregateFunction命名空间不同,在引用时不要写错。不过这两个AggregateFunction的工作原理大同小异。首先,AggregateFunction调用createAccumulator方法创建一个累加器,这里简称ACC,ACC用来存储中间结果。接着,每当表中有新数据到达,Flink SQL会调用accumulate方法,新数据会作用在ACC上,ACC被更新。当一个分组的所有数据都被accumulate处理,getValue方法可以将ACC中的中间结果输出。

综上,定义一个AggregateFunction时,这三个方法是必须实现的:

  • createAccumulator:创建ACC,可以使用一个自定义的数据结构。
  • accumulate:处理新流入数据,更新ACC;第一个参数是ACC,第二个以及以后的参数为流入数据。
  • getValue:输出结果,返回值的数据类型T与AggregateFunction<T>中定义的泛型T保持一致。

createAccumulator创建一个ACC。accumulate第一个参数为ACC,第二个及以后的参数为整个AggregateFunction的输入参数,这个方法的作用就是接受输入,并将输入作用到ACC上,更新ACC。getValue返回值的类型T为整个AggregateFunction<T>的输出类型。

除了上面三个方法,下面三个方法需要根据使用情况来决定是否需要定义。例如,在流处理的会话窗口上进行聚合时,必须定义merge方法,因为当发现某行数据恰好可以将两个窗口连接为一个窗口时,merge方法可以将两个窗口内的ACC合并。

  • retract:有界OVER WINDOW场景上,窗口是有界的,需要将早期的数据剔除。
  • merge:将多个ACC合并为一个ACC,常用在流处理的会话窗口分组和批处理分组上。
  • resetAccumulator:重置ACC,用于批处理分组上。

这些方法必须声明为public,且不能是static的,方法名必须与上述名字保持一致。

在主逻辑中,我们注册这个函数,并在SQL语句中使用它:

List<Tuple4<Integer, Long, Long, Timestamp>> list = new ArrayList<>();
list.add(Tuple4.of(1, 100l, 1l, Timestamp.valueOf("2020-03-06 00:00:00")));
list.add(Tuple4.of(1, 200l, 2l, Timestamp.valueOf("2020-03-06 00:00:01")));
list.add(Tuple4.of(3, 300l, 3l, Timestamp.valueOf("2020-03-06 00:00:13")));

DataStream<Tuple4<Integer, Long, Long, Timestamp>> stream = env
            .fromCollection(list)
            .assignTimestampsAndWatermarks(new AscendingTimestampExtractor<Tuple4<Integer, Long, Long, Timestamp>>() {
                @Override
                public long extractAscendingTimestamp(Tuple4<Integer, Long, Long, Timestamp> element) {
                    return element.f3.getTime();
                }
            });

Table table = tEnv.fromDataStream(stream, "id, v, w, ts.rowtime");

tEnv.createTemporaryView("input_table", table);

tEnv.registerFunction("WeightAvg", new WeightedAvg());

Table agg = tEnv.sqlQuery("SELECT id, WeightAvg(v, w) FROM input_table GROUP BY id");