Skip to content
This repository was archived by the owner on Oct 8, 2020. It is now read-only.

Commit e58459f

Browse files
Continued work on backward chaining over Datasets.
1 parent c143b3b commit e58459f

1 file changed

Lines changed: 67 additions & 61 deletions

File tree

sansa-inference-spark/src/main/scala/net/sansa_stack/inference/spark/backwardchaining/BackwardChainingReasonerDataframe.scala

Lines changed: 67 additions & 61 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,8 @@
11
package net.sansa_stack.inference.spark.backwardchaining
22

33

4+
import java.net.URI
5+
46
import org.apache.jena.graph.{Node, NodeFactory, Triple}
57
import org.apache.jena.reasoner.TriplePattern
68
import org.apache.jena.reasoner.rulesys.Rule
@@ -13,6 +15,7 @@ import net.sansa_stack.inference.rules.plan.SimpleSQLGenerator
1315
import net.sansa_stack.inference.spark.backwardchaining.BackwardChainingReasonerDataframe.time
1416
import net.sansa_stack.inference.spark.backwardchaining.tree.{AndNode, OrNode}
1517
import net.sansa_stack.inference.spark.data.loader.RDFGraphLoader
18+
import net.sansa_stack.inference.spark.utils.NTriplesToParquetConverter.{DEFAULT_NUM_THREADS, DEFAULT_PARALLELISM}
1619
import net.sansa_stack.inference.utils.RuleUtils._
1720
import net.sansa_stack.inference.utils.{Logging, TripleUtils}
1821

@@ -386,15 +389,20 @@ class BackwardChainingReasonerDataframe(
386389

387390
object BackwardChainingReasonerDataframe {
388391

392+
val DEFAULT_PARALLELISM = 200
393+
val DEFAULT_NUM_THREADS = 4
389394

390395
def main(args: Array[String]): Unit = {
396+
if (args.length == 0) sys.error("USAGE: BackwardChainingReasonerDataframe <INPUT_PATH>+ <NUM_THREADS>? <PARALLELISM>?")
391397

392-
val parallelism = 200
398+
val inputPath = args(0)
399+
val numThreads = if (args.length > 1) args(1).toInt else DEFAULT_NUM_THREADS
400+
val parallelism = if (args.length > 2) args(2).toInt else DEFAULT_PARALLELISM
393401

394402
// the SPARK config
395403
val session = SparkSession.builder
396404
.appName(s"Spark Backward Chaining")
397-
.master("local[4]")
405+
.master(s"local[$numThreads]")
398406
.config("spark.eventLog.enabled", "true")
399407
.config("spark.hadoop.validateOutputSpecs", "false") // override output files
400408
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
@@ -403,14 +411,12 @@ object BackwardChainingReasonerDataframe {
403411
.config("spark.sql.shuffle.partitions", parallelism)
404412
.config("spark.sql.autoBroadcastJoinThreshold", "10485760")
405413
.config("parquet.enable.summary-metadata", "false")
406-
.config("spark.local.dir", "/home/user/work/datasets/spark/tmp")
414+
// .config("spark.local.dir", "/home/user/work/datasets/spark/tmp")
407415
.getOrCreate()
408416

409417
import session.implicits._
410418
// implicit val myObjEncoder = org.apache.spark.sql.Encoders.kryo[RDFTriple]
411419

412-
val path = "/home/user/work/datasets/lubm/1000/univ-bench.nt"//args(0)
413-
414420
// val triples = RDFGraphLoader.loadFromDisk(session, path)
415421
// .triples.map(t => RDFTriple(t.getSubject.toString(), t.getPredicate.toString(), t.getObject.toString()))
416422
//// .triples.map(t => RDFTriple(FmtUtils.stringForNode(t.getSubject), FmtUtils.stringForNode(t.getPredicate), FmtUtils.stringForNode(t.getObject)))
@@ -419,7 +425,7 @@ object BackwardChainingReasonerDataframe {
419425
// val graph = session.createDataset(triples)//.cache()
420426
// graph.write.mode(SaveMode.Append).parquet(tableDir)
421427

422-
val graph = session.read.parquet(args(0)).as[RDFTriple].cache()
428+
val graph = session.read.parquet(inputPath).as[RDFTriple].cache()
423429
graph.createOrReplaceTempView("TRIPLES")
424430

425431
// compute size here to have it cached
@@ -445,61 +451,61 @@ object BackwardChainingReasonerDataframe {
445451
NodeFactory.createURI("http://swat.cse.lehigh.edu/onto/univ-bench.owl#Person"))
446452
compare(tp, reasoner)
447453

448-
// // :s rdf:type VAR
449-
// tp = Triple.create(
450-
// NodeFactory.createURI("http://www.Department0.University0.edu/FullProfessor0"),
451-
// RDF.`type`.asNode(),
452-
// NodeFactory.createVariable("o"))
453-
// compare(tp, reasoner)
454-
//
455-
// // VAR :p VAR
456-
// tp = Triple.create(
457-
// NodeFactory.createVariable("s"),
458-
// NodeFactory.createURI("http://swat.cse.lehigh.edu/onto/univ-bench.owl#degreeFrom"),
459-
// NodeFactory.createVariable("o"))
460-
// compare(tp, reasoner)
461-
//
462-
// // :s :p VAR
463-
// tp = Triple.create(
464-
// NodeFactory.createURI("http://www.Department4.University3.edu/GraduateStudent40"),
465-
// NodeFactory.createURI("http://swat.cse.lehigh.edu/onto/univ-bench.owl#degreeFrom"),
466-
// NodeFactory.createVariable("o"))
467-
// compare(tp, reasoner)
468-
//
469-
// // VAR :p :o
470-
// tp = Triple.create(
471-
// NodeFactory.createVariable("s"),
472-
// NodeFactory.createURI("http://swat.cse.lehigh.edu/onto/univ-bench.owl#degreeFrom"),
473-
// NodeFactory.createURI("http://www.University801.edu"))
474-
// compare(tp, reasoner)
475-
//
476-
// // :s VAR :o
477-
// tp = Triple.create(
478-
// NodeFactory.createURI("http://www.Department4.University3.edu/GraduateStudent40"),
479-
// NodeFactory.createVariable("p"),
480-
// NodeFactory.createURI("http://www.University801.edu"))
481-
// compare(tp, reasoner)
482-
//
483-
// // :s VAR VAR where :s is a resource
484-
// tp = Triple.create(
485-
// NodeFactory.createURI("http://www.Department4.University3.edu/GraduateStudent40"),
486-
// NodeFactory.createVariable("p"),
487-
// NodeFactory.createVariable("o"))
488-
// compare(tp, reasoner)
489-
//
490-
// // :s VAR VAR where :s is a class
491-
// tp = Triple.create(
492-
// NodeFactory.createURI("http://swat.cse.lehigh.edu/onto/univ-bench.owl#Book"),
493-
// NodeFactory.createVariable("p"),
494-
// NodeFactory.createVariable("o"))
495-
// compare(tp, reasoner)
496-
//
497-
// // :s VAR VAR where :s is a property
498-
// tp = Triple.create(
499-
// NodeFactory.createURI("http://swat.cse.lehigh.edu/onto/univ-bench.owl#undergraduateDegreeFrom"),
500-
// NodeFactory.createVariable("p"),
501-
// NodeFactory.createVariable("o"))
502-
// compare(tp, reasoner)
454+
// :s rdf:type VAR
455+
tp = Triple.create(
456+
NodeFactory.createURI("http://www.Department0.University0.edu/FullProfessor0"),
457+
RDF.`type`.asNode(),
458+
NodeFactory.createVariable("o"))
459+
compare(tp, reasoner)
460+
461+
// VAR :p VAR
462+
tp = Triple.create(
463+
NodeFactory.createVariable("s"),
464+
NodeFactory.createURI("http://swat.cse.lehigh.edu/onto/univ-bench.owl#degreeFrom"),
465+
NodeFactory.createVariable("o"))
466+
compare(tp, reasoner)
467+
468+
// :s :p VAR
469+
tp = Triple.create(
470+
NodeFactory.createURI("http://www.Department4.University3.edu/GraduateStudent40"),
471+
NodeFactory.createURI("http://swat.cse.lehigh.edu/onto/univ-bench.owl#degreeFrom"),
472+
NodeFactory.createVariable("o"))
473+
compare(tp, reasoner)
474+
475+
// VAR :p :o
476+
tp = Triple.create(
477+
NodeFactory.createVariable("s"),
478+
NodeFactory.createURI("http://swat.cse.lehigh.edu/onto/univ-bench.owl#degreeFrom"),
479+
NodeFactory.createURI("http://www.University801.edu"))
480+
compare(tp, reasoner)
481+
482+
// :s VAR :o
483+
tp = Triple.create(
484+
NodeFactory.createURI("http://www.Department4.University3.edu/GraduateStudent40"),
485+
NodeFactory.createVariable("p"),
486+
NodeFactory.createURI("http://www.University801.edu"))
487+
compare(tp, reasoner)
488+
489+
// :s VAR VAR where :s is a resource
490+
tp = Triple.create(
491+
NodeFactory.createURI("http://www.Department4.University3.edu/GraduateStudent40"),
492+
NodeFactory.createVariable("p"),
493+
NodeFactory.createVariable("o"))
494+
compare(tp, reasoner)
495+
496+
// :s VAR VAR where :s is a class
497+
tp = Triple.create(
498+
NodeFactory.createURI("http://swat.cse.lehigh.edu/onto/univ-bench.owl#Book"),
499+
NodeFactory.createVariable("p"),
500+
NodeFactory.createVariable("o"))
501+
compare(tp, reasoner)
502+
503+
// :s VAR VAR where :s is a property
504+
tp = Triple.create(
505+
NodeFactory.createURI("http://swat.cse.lehigh.edu/onto/univ-bench.owl#undergraduateDegreeFrom"),
506+
NodeFactory.createVariable("p"),
507+
NodeFactory.createVariable("o"))
508+
compare(tp, reasoner)
503509

504510
session.stop()
505511
}

0 commit comments

Comments
 (0)