 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.UUID;
 import java.util.concurrent.TimeUnit;
 
 /**
@@ -269,7 +270,40 @@ public void testScalaSparkProgramClosure() throws Exception {
   }
 
   @Test
-  public void testScalaSparkCompute() throws Exception {
+  public void testScalaSparkComputeDataFrame() throws Exception {
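+    // Scala snippet: explode the "body" field into words, then count each word with Spark SQL.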
+    StringWriter codeWriter = new StringWriter();
+    try (PrintWriter printer = new PrintWriter(codeWriter, true)) {
+      printer.println("def transform(df: DataFrame) : DataFrame = {");
+      printer.println("  val splitted = df.explode(\"body\", \"word\") {");
+      printer.println("    line: String => line.split(\"\\\\s+\")");
+      printer.println("  }");
+      printer.println("  splitted.registerTempTable(\"splitted\")");
+      printer.println("  splitted.sqlContext.sql(\"SELECT word, count(*) as count FROM splitted GROUP BY word\")");
+      printer.println("}");
+    }
+
+    testWordCountCompute(codeWriter.toString());
+  }
+
+  @Test
+  public void testScalaSparkComputeRDD() throws Exception {
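+    // Scala snippet: the same word count, written against RDD[StructuredRecord] instead of a DataFrame.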
+    StringWriter codeWriter = new StringWriter();
+    try (PrintWriter printer = new PrintWriter(codeWriter, true)) {
+      printer.println(
+        "def transform(rdd: RDD[StructuredRecord], context:SparkExecutionPluginContext) : RDD[StructuredRecord] = {");
+      printer.println("  val schema = context.getOutputSchema");
+      printer.println("  rdd");
+      printer.println("    .flatMap(_.get[String](\"body\").split(\"\\\\s+\"))");
+      printer.println("    .map(s => (s, 1L))");
+      printer.println("    .reduceByKey(_ + _)");
+      printer.println("    .map(t => StructuredRecord.builder(schema).set(\"word\", t._1).set(\"count\", t._2).build)");
+      printer.println("}");
+    }
+
+    testWordCountCompute(codeWriter.toString());
+  }
+
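+  // Shared driver: deploys a mock source -> ScalaSparkCompute -> mock sink pipeline around the
+  // given Scala code and verifies the word counts written to the sink.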
+  private void testWordCountCompute(String code) throws Exception {
     Schema inputSchema = Schema.recordOf(
       "input",
       Schema.Field.of("body", Schema.nullableOf(Schema.of(Schema.Type.STRING)))
@@ -281,25 +315,17 @@ public void testScalaSparkCompute() throws Exception {
       Schema.Field.of("count", Schema.nullableOf(Schema.of(Schema.Type.LONG)))
     );
 
-    StringWriter codeWriter = new StringWriter();
-    try (PrintWriter printer = new PrintWriter(codeWriter, true)) {
-      printer.println("def transform(df: DataFrame) : DataFrame = {");
-      printer.println("  val splitted = df.explode(\"body\", \"word\") {");
-      printer.println("    line: String => line.split(\"\\\\s+\")");
-      printer.println("  }");
-      printer.println("  splitted.registerTempTable(\"splitted\")");
-      printer.println("  splitted.sqlContext.sql(\"SELECT word, count(*) as count FROM splitted GROUP BY word\")");
-      printer.println("}");
-    }
+    String inputTable = UUID.randomUUID().toString();
+    String outputTable = UUID.randomUUID().toString();
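+    // Unique dataset names keep the DataFrame and RDD runs of this helper from sharing tables.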
 
     // Pipeline configuration
     ETLBatchConfig etlConfig = ETLBatchConfig.builder("* * * * *")
-      .addStage(new ETLStage("source", MockSource.getPlugin("singleInput", inputSchema)))
+      .addStage(new ETLStage("source", MockSource.getPlugin(inputTable, inputSchema)))
       .addStage(new ETLStage("compute", new ETLPlugin("ScalaSparkCompute", SparkCompute.PLUGIN_TYPE, ImmutableMap.of(
-        "scalaCode", codeWriter.toString(),
+        "scalaCode", code,
         "schema", computeSchema.toString()
       ))))
-      .addStage(new ETLStage("sink", MockSink.getPlugin("singleOutput")))
+      .addStage(new ETLStage("sink", MockSink.getPlugin(outputTable)))
       .addConnection("source", "compute")
       .addConnection("compute", "sink")
       .build();
@@ -308,11 +334,11 @@ public void testScalaSparkCompute() throws Exception {
     ArtifactSummary artifactSummary = new ArtifactSummary(DATAPIPELINE_ARTIFACT_ID.getArtifact(),
                                                           DATAPIPELINE_ARTIFACT_ID.getVersion());
     AppRequest<ETLBatchConfig> appRequest = new AppRequest<>(artifactSummary, etlConfig);
-    ApplicationId appId = NamespaceId.DEFAULT.app("ScalaSparkComputeApp");
+    ApplicationId appId = NamespaceId.DEFAULT.app(UUID.randomUUID().toString());
     ApplicationManager appManager = deployApplication(appId, appRequest);
 
     // write records to source
-    DataSetManager<Table> inputManager = getDataset(NamespaceId.DEFAULT.dataset("singleInput"));
+    DataSetManager<Table> inputManager = getDataset(NamespaceId.DEFAULT.dataset(inputTable));
     List<StructuredRecord> inputRecords = new ArrayList<>();
     for (int i = 0; i < 10; i++) {
       inputRecords.add(StructuredRecord.builder(inputSchema).set("body", "Line " + i).build());
@@ -326,7 +352,7 @@ public void testScalaSparkCompute() throws Exception {
 
     // Verify result written to sink.
     // It has two fields, word and count.
-    DataSetManager<Table> sinkManager = getDataset("singleOutput");
+    DataSetManager<Table> sinkManager = getDataset(outputTable);
     Map<String, StructuredRecord> wordCounts =
       Maps.uniqueIndex(Sets.newHashSet(MockSink.readOutput(sinkManager)), new Function<StructuredRecord, String>() {
         @Override
@@ -343,13 +369,28 @@ public String apply(StructuredRecord record) {
   }
 
   @Test
-  public void testScalaSparkSink() throws Exception {
-    Schema inputSchema = Schema.recordOf(
-      "input",
-      Schema.Field.of("body", Schema.nullableOf(Schema.of(Schema.Type.STRING)))
-    );
+  public void testScalaSparkSinkRDD() throws Exception {
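+    // Scala snippet: word count over RDD[StructuredRecord], written out with saveAsTextFile.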
+    File testFolder = TEMP_FOLDER.newFolder("scalaSinkRDDOutput");
+    File outputFolder = new File(testFolder, "output");
+    StringWriter codeWriter = new StringWriter();
+    try (PrintWriter printer = new PrintWriter(codeWriter, true)) {
+      printer.println(
+        "def sink(rdd: RDD[StructuredRecord], context:SparkExecutionPluginContext) : Unit = {");
+      printer.println("  val schema = context.getOutputSchema");
+      printer.println("  rdd");
+      printer.println("    .flatMap(_.get[String](\"body\").split(\"\\\\s+\"))");
+      printer.println("    .map(s => (s, 1L))");
+      printer.println("    .reduceByKey(_ + _)");
+      printer.println("    .map(t => t._1 + \" \" + t._2)");
+      printer.println("    .saveAsTextFile(\"" + outputFolder.getAbsolutePath() + "\")");
+      printer.println("}");
+    }
+    testWordCountSink(codeWriter.toString(), outputFolder);
+  }
 
-    File testFolder = TEMP_FOLDER.newFolder("scalaSinkOutput");
+  @Test
+  public void testScalaSparkSinkDataFrame() throws Exception {
+    File testFolder = TEMP_FOLDER.newFolder("scalaSinkDataframeOutput");
     File outputFolder = new File(testFolder, "output");
     StringWriter codeWriter = new StringWriter();
     try (PrintWriter printer = new PrintWriter(codeWriter, true)) {
@@ -363,24 +404,34 @@ public void testScalaSparkSink() throws Exception {
       printer.println("  out.write.format(\"text\").save(\"" + outputFolder.getAbsolutePath() + "\")");
       printer.println("}");
     }
+    testWordCountSink(codeWriter.toString(), outputFolder);
+  }
+
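+  // Shared driver: runs a mock source -> ScalaSparkSink pipeline with the given Scala code;
+  // both sink snippets write their word counts under outputFolder.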
+  private void testWordCountSink(String code, File outputFolder) throws Exception {
+    Schema inputSchema = Schema.recordOf(
+      "input",
+      Schema.Field.of("body", Schema.nullableOf(Schema.of(Schema.Type.STRING)))
+    );
+
+    String inputTable = UUID.randomUUID().toString();
 
     // Pipeline configuration
     ETLBatchConfig etlConfig = ETLBatchConfig.builder("* * * * *")
-      .addStage(new ETLStage("source", MockSource.getPlugin("sinkInput", inputSchema)))
+      .addStage(new ETLStage("source", MockSource.getPlugin(inputTable, inputSchema)))
       .addStage(new ETLStage("sink", new ETLPlugin("ScalaSparkSink", SparkSink.PLUGIN_TYPE,
-                             ImmutableMap.of("scalaCode", codeWriter.toString()))))
+                             ImmutableMap.of("scalaCode", code))))
       .addConnection("source", "sink")
       .build();
 
     // Deploy the pipeline
     ArtifactSummary artifactSummary = new ArtifactSummary(DATAPIPELINE_ARTIFACT_ID.getArtifact(),
                                                           DATAPIPELINE_ARTIFACT_ID.getVersion());
     AppRequest<ETLBatchConfig> appRequest = new AppRequest<>(artifactSummary, etlConfig);
-    ApplicationId appId = NamespaceId.DEFAULT.app("ScalaSparkSinkApp");
+    ApplicationId appId = NamespaceId.DEFAULT.app(UUID.randomUUID().toString());
     ApplicationManager appManager = deployApplication(appId, appRequest);
 
     // write records to source
-    DataSetManager<Table> inputManager = getDataset(NamespaceId.DEFAULT.dataset("sinkInput"));
+    DataSetManager<Table> inputManager = getDataset(NamespaceId.DEFAULT.dataset(inputTable));
     List<StructuredRecord> inputRecords = new ArrayList<>();
     for (int i = 0; i < 10; i++) {
       inputRecords.add(StructuredRecord.builder(inputSchema).set("body", "Line " + i).build());