
Commit 89ec2cb

albertshau authored and vravish committed
CDAP-14107 fix bug with RDD functions
Fixing a bug that was introduced during refactoring for the Scala Spark sink, where RDD functions were not being handled correctly. In compute plugins, we were assuming the return value was a DataFrame and would end up calling .toJavaRDD() on a JavaRDD. In sink plugins, we were assuming the return value was an RDD instead of Unit, which would result in a null pointer exception. Added correct handling so that we only try to convert a DataFrame to an RDD if it actually is a DataFrame, and we don't treat a null return as an RDD.
1 parent 848217c commit 89ec2cb
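
As a rough sketch of the handling described above (a hypothetical helper, not the exact plugin code; it assumes Spark's RDD/JavaRDD classes and CDAP's StructuredRecord on the classpath), the raw result of invoking the user's Scala function is only converted when it really is a Scala RDD, an already-converted JavaRDD is passed through untouched, and the null produced by a Unit-returning sink function is never cast to an RDD:

// Hypothetical sketch of the dispatch described in the commit message; the real
// handling lives in ScalaSparkCodeExecutor and ScalaSparkCompute (see the diff below).
import io.cdap.cdap.api.data.format.StructuredRecord;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.rdd.RDD;

final class ResultHandlingSketch {

  private ResultHandlingSketch() { }

  @SuppressWarnings("unchecked")
  static JavaRDD<StructuredRecord> asJavaRDD(Object result) {
    if (result == null) {
      // A Unit-returning sink function surfaces as null here; casting it to an RDD
      // is what caused the NullPointerException described above.
      return null;
    }
    if (result instanceof JavaRDD) {
      // Already a JavaRDD; calling .toJavaRDD() on it again was the compute-plugin bug.
      return (JavaRDD<StructuredRecord>) result;
    }
    if (result instanceof RDD) {
      // A Scala RDD[StructuredRecord] returned by an RDD-based transform.
      return ((RDD<StructuredRecord>) result).toJavaRDD();
    }
    // Anything else (e.g. a DataFrame) is converted back to records by a separate path.
    throw new IllegalArgumentException("Not an RDD result: " + result.getClass().getName());
  }
}

ScalaSparkCompute.transform() in the diff below does effectively this, using the new isDataFrame() accessor on ScalaSparkCodeExecutor to decide which branch applies.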

File tree

3 files changed (+93, -30 lines)


src/main/java/io/cdap/plugin/spark/dynamic/ScalaSparkCodeExecutor.java

Lines changed: 6 additions & 2 deletions
@@ -165,6 +165,10 @@ public void onEvent(SparkListenerEvent event) {
     }
   }
 
+  public boolean isDataFrame() {
+    return isDataFrame;
+  }
+
   /**
    * Execute interpreted code on the given RDD.
    */
@@ -177,10 +181,10 @@ public Object execute(SparkExecutionPluginContext context,
     if (!isDataFrame) {
       if (takeContext) {
         //noinspection unchecked
-        return ((RDD<StructuredRecord>) method.invoke(null, javaRDD.rdd(), context)).toJavaRDD();
+        return method.invoke(null, javaRDD.rdd(), context);
       } else {
         //noinspection unchecked
-        return ((RDD<StructuredRecord>) method.invoke(null, javaRDD.rdd())).toJavaRDD();
+        return method.invoke(null, javaRDD.rdd());
       }
     }
 
src/main/java/io/cdap/plugin/spark/dynamic/ScalaSparkCompute.java

Lines changed: 9 additions & 1 deletion
@@ -30,6 +30,7 @@
 import io.cdap.cdap.etl.api.batch.SparkExecutionPluginContext;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
+import org.apache.spark.rdd.RDD;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.types.DataType;
 
@@ -48,6 +49,7 @@ public class ScalaSparkCompute extends SparkCompute<StructuredRecord, Structured
   // A strong reference is needed to keep the compiled classes around
   @SuppressWarnings("FieldCanBeLocal")
   private transient ScalaSparkCodeExecutor codeExecutor;
+  private transient boolean isRDD;
 
   public ScalaSparkCompute(Config config) {
     this.config = config;
@@ -77,13 +79,19 @@ public void configurePipeline(PipelineConfigurer pipelineConfigurer) throws Ille
   public void initialize(SparkExecutionPluginContext context) throws Exception {
     codeExecutor = new ScalaSparkCodeExecutor(config.getScalaCode(), config.getDependencies(), "transform", false);
     codeExecutor.initialize(context);
+    isRDD = !codeExecutor.isDataFrame();
   }
 
   @Override
   public JavaRDD<StructuredRecord> transform(SparkExecutionPluginContext context,
                                              JavaRDD<StructuredRecord> javaRDD) throws Exception {
     Object result = codeExecutor.execute(context, javaRDD);
 
+    if (isRDD) {
+      //noinspection unchecked
+      return ((RDD<StructuredRecord>) result).toJavaRDD();
+    }
+
     // Convert the DataFrame back to RDD<StructureRecord>
     Schema outputSchema = context.getOutputSchema();
     if (outputSchema == null) {
@@ -176,7 +184,7 @@ public RowToRecord(Schema schema) {
     }
 
     @Override
-    public StructuredRecord call(Row row) throws Exception {
+    public StructuredRecord call(Row row) {
       return DataFrames.fromRow(row, schema);
     }
   }

src/test/java/io/cdap/plugin/spark/dynamic/ScalaSparkTest.java

Lines changed: 78 additions & 27 deletions
@@ -72,6 +72,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.UUID;
 import java.util.concurrent.TimeUnit;
 
 /**
@@ -269,7 +270,40 @@ public void testScalaSparkProgramClosure() throws Exception {
   }
 
   @Test
-  public void testScalaSparkCompute() throws Exception {
+  public void testScalaSparkComputeDataFrame() throws Exception {
+    StringWriter codeWriter = new StringWriter();
+    try (PrintWriter printer = new PrintWriter(codeWriter, true)) {
+      printer.println("def transform(df: DataFrame) : DataFrame = {");
+      printer.println("  val splitted = df.explode(\"body\", \"word\") { ");
+      printer.println("    line: String => line.split(\"\\\\s+\")");
+      printer.println("  }");
+      printer.println("  splitted.registerTempTable(\"splitted\")");
+      printer.println("  splitted.sqlContext.sql(\"SELECT word, count(*) as count FROM splitted GROUP BY word\")");
+      printer.println("}");
+    }
+
+    testWordCountCompute(codeWriter.toString());
+  }
+
+  @Test
+  public void testScalaSparkComputeRDD() throws Exception {
+    StringWriter codeWriter = new StringWriter();
+    try (PrintWriter printer = new PrintWriter(codeWriter, true)) {
+      printer.println(
+        "def transform(rdd: RDD[StructuredRecord], context:SparkExecutionPluginContext) : RDD[StructuredRecord] = {");
+      printer.println("  val schema = context.getOutputSchema");
+      printer.println("  rdd");
+      printer.println("    .flatMap(_.get[String](\"body\").split(\"\\\\s+\"))");
+      printer.println("    .map(s => (s, 1L))");
+      printer.println("    .reduceByKey(_ + _)");
+      printer.println("    .map(t => StructuredRecord.builder(schema).set(\"word\", t._1).set(\"count\", t._2).build)");
+      printer.println("}");
+    }
+
+    testWordCountCompute(codeWriter.toString());
+  }
+
+  private void testWordCountCompute(String code) throws Exception {
     Schema inputSchema = Schema.recordOf(
       "input",
       Schema.Field.of("body", Schema.nullableOf(Schema.of(Schema.Type.STRING)))
@@ -281,25 +315,17 @@ public void testScalaSparkCompute() throws Exception {
       Schema.Field.of("count", Schema.nullableOf(Schema.of(Schema.Type.LONG)))
     );
 
-    StringWriter codeWriter = new StringWriter();
-    try (PrintWriter printer = new PrintWriter(codeWriter, true)) {
-      printer.println("def transform(df: DataFrame) : DataFrame = {");
-      printer.println("  val splitted = df.explode(\"body\", \"word\") { ");
-      printer.println("    line: String => line.split(\"\\\\s+\")");
-      printer.println("  }");
-      printer.println("  splitted.registerTempTable(\"splitted\")");
-      printer.println("  splitted.sqlContext.sql(\"SELECT word, count(*) as count FROM splitted GROUP BY word\")");
-      printer.println("}");
-    }
+    String inputTable = UUID.randomUUID().toString();
+    String outputTable = UUID.randomUUID().toString();
 
     // Pipeline configuration
     ETLBatchConfig etlConfig = ETLBatchConfig.builder("* * * * *")
-      .addStage(new ETLStage("source", MockSource.getPlugin("singleInput", inputSchema)))
+      .addStage(new ETLStage("source", MockSource.getPlugin(inputTable, inputSchema)))
       .addStage(new ETLStage("compute", new ETLPlugin("ScalaSparkCompute", SparkCompute.PLUGIN_TYPE, ImmutableMap.of(
-        "scalaCode", codeWriter.toString(),
+        "scalaCode", code,
         "schema", computeSchema.toString()
       ))))
-      .addStage(new ETLStage("sink", MockSink.getPlugin("singleOutput")))
+      .addStage(new ETLStage("sink", MockSink.getPlugin(outputTable)))
       .addConnection("source", "compute")
       .addConnection("compute", "sink")
       .build();
@@ -308,11 +334,11 @@
     ArtifactSummary artifactSummary = new ArtifactSummary(DATAPIPELINE_ARTIFACT_ID.getArtifact(),
                                                           DATAPIPELINE_ARTIFACT_ID.getVersion());
     AppRequest<ETLBatchConfig> appRequest = new AppRequest<>(artifactSummary, etlConfig);
-    ApplicationId appId = NamespaceId.DEFAULT.app("ScalaSparkComputeApp");
+    ApplicationId appId = NamespaceId.DEFAULT.app(UUID.randomUUID().toString());
    ApplicationManager appManager = deployApplication(appId, appRequest);
 
     // write records to source
-    DataSetManager<Table> inputManager = getDataset(NamespaceId.DEFAULT.dataset("singleInput"));
+    DataSetManager<Table> inputManager = getDataset(NamespaceId.DEFAULT.dataset(inputTable));
     List<StructuredRecord> inputRecords = new ArrayList<>();
     for (int i = 0; i < 10; i++) {
       inputRecords.add(StructuredRecord.builder(inputSchema).set("body", "Line " + i).build());
@@ -326,7 +352,7 @@
 
     // Verify result written to sink.
     // It has two fields, word and count.
-    DataSetManager<Table> sinkManager = getDataset("singleOutput");
+    DataSetManager<Table> sinkManager = getDataset(outputTable);
     Map<String, StructuredRecord> wordCounts =
       Maps.uniqueIndex(Sets.newHashSet(MockSink.readOutput(sinkManager)), new Function<StructuredRecord, String>() {
         @Override
@@ -343,13 +369,28 @@ public String apply(StructuredRecord record) {
   }
 
   @Test
-  public void testScalaSparkSink() throws Exception {
-    Schema inputSchema = Schema.recordOf(
-      "input",
-      Schema.Field.of("body", Schema.nullableOf(Schema.of(Schema.Type.STRING)))
-    );
+  public void testScalaSparkSinkRDD() throws Exception {
+    File testFolder = TEMP_FOLDER.newFolder("scalaSinkRDDOutput");
+    File outputFolder = new File(testFolder, "output");
+    StringWriter codeWriter = new StringWriter();
+    try (PrintWriter printer = new PrintWriter(codeWriter, true)) {
+      printer.println(
+        "def sink(rdd: RDD[StructuredRecord], context:SparkExecutionPluginContext) : Unit = {");
+      printer.println("  val schema = context.getOutputSchema");
+      printer.println("  rdd");
+      printer.println("    .flatMap(_.get[String](\"body\").split(\"\\\\s+\"))");
+      printer.println("    .map(s => (s, 1L))");
+      printer.println("    .reduceByKey(_ + _)");
+      printer.println("    .map(t => t._1 + \" \" + t._2)");
+      printer.println("    .saveAsTextFile(\"" + outputFolder.getAbsolutePath() + "\")");
+      printer.println("}");
+    }
+    testWordCountSink(codeWriter.toString(), outputFolder);
+  }
 
-    File testFolder = TEMP_FOLDER.newFolder("scalaSinkOutput");
+  @Test
+  public void testScalaSparkSinkDataFrame() throws Exception {
+    File testFolder = TEMP_FOLDER.newFolder("scalaSinkDataframeOutput");
     File outputFolder = new File(testFolder, "output");
     StringWriter codeWriter = new StringWriter();
     try (PrintWriter printer = new PrintWriter(codeWriter, true)) {
@@ -363,24 +404,34 @@ public void testScalaSparkSink() throws Exception {
       printer.println("  out.write.format(\"text\").save(\"" + outputFolder.getAbsolutePath() + "\")");
       printer.println("}");
     }
+    testWordCountSink(codeWriter.toString(), outputFolder);
+  }
+
+  private void testWordCountSink(String code, File outputFolder) throws Exception {
+    Schema inputSchema = Schema.recordOf(
+      "input",
+      Schema.Field.of("body", Schema.nullableOf(Schema.of(Schema.Type.STRING)))
+    );
+
+    String inputTable = UUID.randomUUID().toString();
 
     // Pipeline configuration
     ETLBatchConfig etlConfig = ETLBatchConfig.builder("* * * * *")
-      .addStage(new ETLStage("source", MockSource.getPlugin("sinkInput", inputSchema)))
+      .addStage(new ETLStage("source", MockSource.getPlugin(inputTable, inputSchema)))
       .addStage(new ETLStage("sink", new ETLPlugin("ScalaSparkSink", SparkSink.PLUGIN_TYPE,
-                                                    ImmutableMap.of("scalaCode", codeWriter.toString()))))
+                                                    ImmutableMap.of("scalaCode", code))))
       .addConnection("source", "sink")
       .build();
 
     // Deploy the pipeline
     ArtifactSummary artifactSummary = new ArtifactSummary(DATAPIPELINE_ARTIFACT_ID.getArtifact(),
                                                           DATAPIPELINE_ARTIFACT_ID.getVersion());
     AppRequest<ETLBatchConfig> appRequest = new AppRequest<>(artifactSummary, etlConfig);
-    ApplicationId appId = NamespaceId.DEFAULT.app("ScalaSparkSinkApp");
+    ApplicationId appId = NamespaceId.DEFAULT.app(UUID.randomUUID().toString());
     ApplicationManager appManager = deployApplication(appId, appRequest);
 
     // write records to source
-    DataSetManager<Table> inputManager = getDataset(NamespaceId.DEFAULT.dataset("sinkInput"));
+    DataSetManager<Table> inputManager = getDataset(NamespaceId.DEFAULT.dataset(inputTable));
     List<StructuredRecord> inputRecords = new ArrayList<>();
     for (int i = 0; i < 10; i++) {
       inputRecords.add(StructuredRecord.builder(inputSchema).set("body", "Line " + i).build());
