Skip to content

Commit 61b81db

Browse files
committed
Mean weight is calculated as sum of weights over number of possible edges
1 parent 6074177 commit 61b81db

17 files changed

Lines changed: 97 additions & 35 deletions

File tree

src/main/java/edu/virginia/uvacluster/internal/SupervisedModel.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,13 @@ private CySubNetwork genNegativeExample(int size, List<CySubNetwork> positiveExa
285285
*/
286286
public void train(List<CySubNetwork> positiveExamples, List<CySubNetwork> negativeExamples) {
287287
System.out.println("Entered TRAIN");
288+
289+
for (FeatureSet feature : features) {
290+
for (String desc : feature.getDescriptions()) {
291+
System.out.println(desc);
292+
}
293+
}
294+
288295
List<Cluster> posExamples = new ArrayList<Cluster>(), negExamples = new ArrayList<Cluster>();
289296
for(CySubNetwork pos: positiveExamples) {posExamples.add(new Cluster(features, pos)); }
290297
for(CySubNetwork neg: negativeExamples) {negExamples.add(new Cluster(features, neg));}

src/main/java/edu/virginia/uvacluster/internal/feature/EdgeWeight.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
package edu.virginia.uvacluster.internal.feature;
22

3+
import java.util.ArrayList;
34
import java.util.List;
45

6+
import org.cytoscape.model.CyEdge;
7+
import org.cytoscape.model.CyTable;
8+
9+
import edu.virginia.uvacluster.internal.Cluster;
510
import edu.virginia.uvacluster.internal.statistic.Statistic;
611

712
public class EdgeWeight extends EdgeTableFeature{

src/main/java/edu/virginia/uvacluster/internal/feature/FeatureSet.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,22 @@ public FeatureSet(String description, List<Statistic> statistics) {
2424

2525
public List<Statistic> train(Cluster cluster) {
2626
for (Statistic statistic: statistics)
27-
statistic.train(computeInputs(cluster));
27+
statistic.train(computeInputs(cluster), cluster);
2828
// System.out.println("Training feature");
2929
return statistics;
3030
}
3131

3232
public List<Double> getValues(Cluster cluster) {
3333
List<Double> result = new ArrayList<Double>();
3434
for(Statistic statistic: statistics)
35-
result.add(statistic.transform(computeInputs(cluster)));
35+
result.add(statistic.transform(computeInputs(cluster), cluster));
3636
return result;
3737
}
3838

3939
public List<Integer> getBinnedValues(Cluster cluster) {
4040
List<Integer> result = new ArrayList<Integer>();
4141
for(Statistic statistic: statistics)
42-
result.add(statistic.binTransform(computeInputs(cluster)));
42+
result.add(statistic.binTransform(computeInputs(cluster), cluster));
4343
return result;
4444
}
4545

src/main/java/edu/virginia/uvacluster/internal/feature/FeatureUtil.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ public static List<FeatureSet> parse(Set<String> featureKeys) {
4848
m.matches();
4949
numBins = Integer.parseInt(m.group(2));
5050

51+
// Check if Mean : weight -- to use alternative mean calculation
52+
// over all possible edges
53+
if (weightFeaturePattern.matcher(featureName).matches())
54+
statName = "mean possible";
55+
5156
if (statsMap.get(featureName) == null)
5257
statsMap.put(featureName, new HashSet<Statistic>());
5358
statsMap.get(featureName).add(getStat(statName, numBins));
@@ -84,6 +89,9 @@ private static Statistic getStat(String name, int numBins) {
8489
case "mean":
8590
stat = new Mean(range);
8691
break;
92+
case "mean possible":
93+
stat = new MeanPossible(range);
94+
break;
8795
case "median":
8896
stat = new Median(range);
8997
break;

src/main/java/edu/virginia/uvacluster/internal/statistic/Count.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22

33
import java.util.List;
44

5+
import edu.virginia.uvacluster.internal.Cluster;
6+
57
public class Count extends Statistic {
68

79
public Count(StatisticRange range) {
810
super(range,"count");
911
}
1012

1113
@Override
12-
public double transform(List<Double> values) {
14+
public double transform(List<Double> values, Cluster cluster) {
1315
return values.size();
1416
}
1517
}

src/main/java/edu/virginia/uvacluster/internal/statistic/Max.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
import java.util.Collections;
55
import java.util.List;
66

7+
import edu.virginia.uvacluster.internal.Cluster;
8+
79
public class Max extends Statistic {
810

911
public Max(StatisticRange range) {
1012
super(range, "max");
1113
}
1214

13-
public double transform(List<Double> values) {
15+
public double transform(List<Double> values, Cluster cluster) {
1416
ArrayList<Double> copy = new ArrayList<Double>(values);
1517
double max = 0;
1618
int numElements = copy.size();

src/main/java/edu/virginia/uvacluster/internal/statistic/Mean.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22

33
import java.util.List;
44

5+
import edu.virginia.uvacluster.internal.Cluster;
6+
57
public class Mean extends Statistic {
68

79
public Mean(StatisticRange range) {
810
super(range, "mean");
911
}
1012

11-
public double transform(List<Double> values) {
13+
public double transform(List<Double> values, Cluster cluster) {
1214
double mean = 0;
1315
double sum = 0;
1416
double numElements = values.size();
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package edu.virginia.uvacluster.internal.statistic;
2+
3+
import java.util.List;
4+
5+
import edu.virginia.uvacluster.internal.Cluster;
6+
7+
public class MeanPossible extends Statistic {
8+
9+
public MeanPossible(StatisticRange range) {
10+
super(range, "mean");
11+
}
12+
13+
public double transform(List<Double> values, Cluster cluster) {
14+
double sum = 0;
15+
int clusterSize = cluster.size();
16+
double possibleEdges = clusterSize * (clusterSize - 1) / 2;
17+
18+
for (Double value: values)
19+
{
20+
sum += value;
21+
}
22+
23+
return sum / possibleEdges;
24+
}
25+
26+
}

src/main/java/edu/virginia/uvacluster/internal/statistic/Median.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
import java.util.Collections;
55
import java.util.List;
66

7+
import edu.virginia.uvacluster.internal.Cluster;
8+
79
public class Median extends Statistic {
810

911
public Median(StatisticRange range) {
1012
super(range, "median");
1113
}
1214

13-
public double transform(List<Double> values) {
15+
public double transform(List<Double> values, Cluster cluster) {
1416
ArrayList<Double> copy = new ArrayList<Double>(values);
1517
double median = 0;
1618
int numElements = copy.size();

src/main/java/edu/virginia/uvacluster/internal/statistic/Ordinal.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import java.util.List;
44

5+
import edu.virginia.uvacluster.internal.Cluster;
6+
57
public class Ordinal extends Statistic {
68
private Integer ordinalNum;
79

@@ -11,7 +13,7 @@ public Ordinal(StatisticRange range, int ordinalNum) {
1113
}
1214

1315
@Override
14-
public double transform(List<Double> values) {
16+
public double transform(List<Double> values, Cluster cluster) {
1517
return values.get(getIndex());
1618
}
1719

0 commit comments

Comments
 (0)